From 68747de9b41d65962bf60fd36038cd42b23bc60b Mon Sep 17 00:00:00 2001 From: chyezh Date: Fri, 28 Jun 2024 11:11:15 +0800 Subject: [PATCH] enhance: implement balancer at streaming coord - add balancer implementation - add channel count fair balance policy - add discover grpc service Signed-off-by: chyezh --- internal/metastore/catalog.go | 12 + .../metastore/kv/streamingcoord/constant.go | 6 + .../metastore/kv/streamingcoord/kv_catalog.go | 62 ++ .../kv/streamingcoord/kv_catalog_test.go | 66 ++ .../mock_StreamingCoordCataLog.go | 135 +++ ...ignmentService_AssignmentDiscoverServer.go | 378 ++++++++ .../server/mock_balancer/mock_Balancer.go | 199 +++++ .../client/mock_manager/mock_ManagerClient.go | 256 ++++++ internal/proto/streaming.proto | 212 ++++- .../server/balancer/balance_timer.go | 53 ++ .../server/balancer/balancer.go | 28 + .../server/balancer/balancer_impl.go | 277 ++++++ .../server/balancer/balancer_test.go | 115 +++ .../server/balancer/channel/manager.go | 223 +++++ .../server/balancer/channel/manager_test.go | 143 ++++ .../server/balancer/channel/pchannel.go | 150 ++++ .../server/balancer/channel/pchannel_test.go | 107 +++ .../server/balancer/policy/init.go | 7 + .../balancer/policy/pchannel_count_fair.go | 69 ++ .../policy/pchannel_count_fair_test.go | 183 ++++ .../server/balancer/policy_registry.go | 65 ++ .../streamingcoord/server/balancer/request.go | 42 + .../server/resource/resource.go | 66 ++ .../server/resource/resource_test.go | 32 + .../server/resource/test_utility.go | 12 + .../server/service/assignment.go | 37 + .../discover/discover_grpc_server_helper.go | 51 ++ .../service/discover/discover_server.go | 98 +++ .../service/discover/discover_server_test.go | 82 ++ .../streamingnode/client/manager/manager.go | 25 + internal/streamingservice/.mockery.yaml | 12 +- .../typeconverter/streaming_node.go | 20 + internal/util/streamingutil/util/topic.go | 30 + .../util/streamingutil/util/topic_test.go | 19 + .../.mockery.yaml => .mockery_pkg.yaml} | 3 + pkg/Makefile | 8 +- pkg/metrics/streaming_service_metrics.go | 6 +- pkg/mocks/mock_kv/mock_MetaKv.go | 807 ++++++++++++++++++ pkg/mq/msgdispatcher/mock_client.go | 13 +- pkg/mq/msgstream/mock_msgstream.go | 11 +- pkg/streaming/util/types/pchannel_info.go | 5 + pkg/streaming/util/types/streaming_node.go | 42 + .../util/types/streaming_node_test.go | 15 + .../walimpls/helper/scanner_helper.go | 32 +- pkg/util/paramtable/component_param.go | 64 +- pkg/util/paramtable/component_param_test.go | 12 + pkg/util/paramtable/param_item.go | 12 + pkg/util/syncutil/async_task_notifier.go | 50 ++ pkg/util/syncutil/async_task_notifier_test.go | 57 ++ pkg/util/typeutil/version.go | 56 ++ pkg/util/typeutil/version_test.go | 29 + 51 files changed, 4405 insertions(+), 89 deletions(-) create mode 100644 internal/metastore/kv/streamingcoord/constant.go create mode 100644 internal/metastore/kv/streamingcoord/kv_catalog.go create mode 100644 internal/metastore/kv/streamingcoord/kv_catalog_test.go create mode 100644 internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go create mode 100644 internal/mocks/proto/mock_streamingpb/mock_StreamingCoordAssignmentService_AssignmentDiscoverServer.go create mode 100644 internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go create mode 100644 internal/mocks/streamingnode/client/mock_manager/mock_ManagerClient.go create mode 100644 internal/streamingcoord/server/balancer/balance_timer.go create mode 100644 internal/streamingcoord/server/balancer/balancer.go create mode 100644 internal/streamingcoord/server/balancer/balancer_impl.go create mode 100644 internal/streamingcoord/server/balancer/balancer_test.go create mode 100644 internal/streamingcoord/server/balancer/channel/manager.go create mode 100644 internal/streamingcoord/server/balancer/channel/manager_test.go create mode 100644 internal/streamingcoord/server/balancer/channel/pchannel.go create mode 100644 internal/streamingcoord/server/balancer/channel/pchannel_test.go create mode 100644 internal/streamingcoord/server/balancer/policy/init.go create mode 100644 internal/streamingcoord/server/balancer/policy/pchannel_count_fair.go create mode 100644 internal/streamingcoord/server/balancer/policy/pchannel_count_fair_test.go create mode 100644 internal/streamingcoord/server/balancer/policy_registry.go create mode 100644 internal/streamingcoord/server/balancer/request.go create mode 100644 internal/streamingcoord/server/resource/resource.go create mode 100644 internal/streamingcoord/server/resource/resource_test.go create mode 100644 internal/streamingcoord/server/resource/test_utility.go create mode 100644 internal/streamingcoord/server/service/assignment.go create mode 100644 internal/streamingcoord/server/service/discover/discover_grpc_server_helper.go create mode 100644 internal/streamingcoord/server/service/discover/discover_server.go create mode 100644 internal/streamingcoord/server/service/discover/discover_server_test.go create mode 100644 internal/streamingnode/client/manager/manager.go create mode 100644 internal/util/streamingutil/typeconverter/streaming_node.go create mode 100644 internal/util/streamingutil/util/topic.go create mode 100644 internal/util/streamingutil/util/topic_test.go rename pkg/{streaming/.mockery.yaml => .mockery_pkg.yaml} (90%) create mode 100644 pkg/mocks/mock_kv/mock_MetaKv.go create mode 100644 pkg/streaming/util/types/streaming_node.go create mode 100644 pkg/streaming/util/types/streaming_node_test.go create mode 100644 pkg/util/syncutil/async_task_notifier.go create mode 100644 pkg/util/syncutil/async_task_notifier_test.go create mode 100644 pkg/util/typeutil/version.go create mode 100644 pkg/util/typeutil/version_test.go diff --git a/internal/metastore/catalog.go b/internal/metastore/catalog.go index 996805253e495..1e3e1cf5c7cb3 100644 --- a/internal/metastore/catalog.go +++ b/internal/metastore/catalog.go @@ -9,6 +9,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/streamingpb" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -186,3 +187,14 @@ type QueryCoordCatalog interface { RemoveCollectionTarget(collectionID int64) error GetCollectionTargets() (map[int64]*querypb.CollectionTarget, error) } + +// StreamingCoordCataLog is the interface for streamingcoord catalog +type StreamingCoordCataLog interface { + // physical channel watch related + + // ListPChannel list all pchannels on milvus. + ListPChannel(ctx context.Context) ([]*streamingpb.PChannelMeta, error) + + // SavePChannel save a pchannel info to metastore. + SavePChannels(ctx context.Context, info []*streamingpb.PChannelMeta) error +} diff --git a/internal/metastore/kv/streamingcoord/constant.go b/internal/metastore/kv/streamingcoord/constant.go new file mode 100644 index 0000000000000..0603aeda4d8d7 --- /dev/null +++ b/internal/metastore/kv/streamingcoord/constant.go @@ -0,0 +1,6 @@ +package streamingcoord + +const ( + MetaPrefix = "streamingcoord-meta" + PChannelMeta = MetaPrefix + "/pchannel-meta" +) diff --git a/internal/metastore/kv/streamingcoord/kv_catalog.go b/internal/metastore/kv/streamingcoord/kv_catalog.go new file mode 100644 index 0000000000000..a607b9805270c --- /dev/null +++ b/internal/metastore/kv/streamingcoord/kv_catalog.go @@ -0,0 +1,62 @@ +package streamingcoord + +import ( + "context" + + "github.com/cockroachdb/errors" + "github.com/golang/protobuf/proto" + + "github.com/milvus-io/milvus/internal/metastore" + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/kv" +) + +// NewCataLog creates a new catalog instance +func NewCataLog(metaKV kv.MetaKv) metastore.StreamingCoordCataLog { + return &catalog{ + metaKV: metaKV, + } +} + +// catalog is a kv based catalog. +type catalog struct { + metaKV kv.MetaKv +} + +// ListPChannels returns all pchannels +func (c *catalog) ListPChannel(ctx context.Context) ([]*streamingpb.PChannelMeta, error) { + keys, values, err := c.metaKV.LoadWithPrefix(PChannelMeta) + if err != nil { + return nil, err + } + + infos := make([]*streamingpb.PChannelMeta, 0, len(values)) + for k, value := range values { + info := &streamingpb.PChannelMeta{} + err = proto.Unmarshal([]byte(value), info) + if err != nil { + return nil, errors.Wrapf(err, "unmarshal pchannel %s failed", keys[k]) + } + infos = append(infos, info) + } + return infos, nil +} + +// SavePChannels saves a pchannel +func (c *catalog) SavePChannels(ctx context.Context, infos []*streamingpb.PChannelMeta) error { + kvs := make(map[string]string, len(infos)) + for _, info := range infos { + key := buildPChannelInfoPath(info.GetChannel().GetName()) + v, err := proto.Marshal(info) + if err != nil { + return errors.Wrapf(err, "marshal pchannel %s failed", info.GetChannel().GetName()) + } + kvs[key] = string(v) + } + return c.metaKV.MultiSave(kvs) +} + +// buildPChannelInfoPath builds the path for pchannel info. +func buildPChannelInfoPath(name string) string { + return PChannelMeta + "/" + name +} diff --git a/internal/metastore/kv/streamingcoord/kv_catalog_test.go b/internal/metastore/kv/streamingcoord/kv_catalog_test.go new file mode 100644 index 0000000000000..60432533ef735 --- /dev/null +++ b/internal/metastore/kv/streamingcoord/kv_catalog_test.go @@ -0,0 +1,66 @@ +package streamingcoord + +import ( + "context" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/mocks/mock_kv" +) + +func TestCatalog(t *testing.T) { + kv := mock_kv.NewMockMetaKv(t) + + kvStorage := make(map[string]string) + kv.EXPECT().LoadWithPrefix(mock.Anything).RunAndReturn(func(s string) ([]string, []string, error) { + keys := make([]string, 0, len(kvStorage)) + vals := make([]string, 0, len(kvStorage)) + for k, v := range kvStorage { + keys = append(keys, k) + vals = append(vals, v) + } + return keys, vals, nil + }) + kv.EXPECT().MultiSave(mock.Anything).RunAndReturn(func(kvs map[string]string) error { + for k, v := range kvs { + kvStorage[k] = v + } + return nil + }) + + catalog := NewCataLog(kv) + metas, err := catalog.ListPChannel(context.Background()) + assert.NoError(t, err) + assert.Empty(t, metas) + + err = catalog.SavePChannels(context.Background(), []*streamingpb.PChannelMeta{ + { + Channel: &streamingpb.PChannelInfo{Name: "test", Term: 1}, + Node: &streamingpb.StreamingNodeInfo{ServerId: 1}, + }, + { + Channel: &streamingpb.PChannelInfo{Name: "test2", Term: 1}, + Node: &streamingpb.StreamingNodeInfo{ServerId: 1}, + }, + }) + assert.NoError(t, err) + + metas, err = catalog.ListPChannel(context.Background()) + assert.NoError(t, err) + assert.Len(t, metas, 2) + + // error path. + kv.EXPECT().LoadWithPrefix(mock.Anything).Unset() + kv.EXPECT().LoadWithPrefix(mock.Anything).Return(nil, nil, errors.New("load error")) + metas, err = catalog.ListPChannel(context.Background()) + assert.Error(t, err) + assert.Nil(t, metas) + + kv.EXPECT().MultiSave(mock.Anything).Unset() + kv.EXPECT().MultiSave(mock.Anything).Return(errors.New("save error")) + assert.Error(t, err) +} diff --git a/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go b/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go new file mode 100644 index 0000000000000..473652f2af141 --- /dev/null +++ b/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go @@ -0,0 +1,135 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mock_metastore + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + streamingpb "github.com/milvus-io/milvus/internal/proto/streamingpb" +) + +// MockStreamingCoordCataLog is an autogenerated mock type for the StreamingCoordCataLog type +type MockStreamingCoordCataLog struct { + mock.Mock +} + +type MockStreamingCoordCataLog_Expecter struct { + mock *mock.Mock +} + +func (_m *MockStreamingCoordCataLog) EXPECT() *MockStreamingCoordCataLog_Expecter { + return &MockStreamingCoordCataLog_Expecter{mock: &_m.Mock} +} + +// ListPChannel provides a mock function with given fields: ctx +func (_m *MockStreamingCoordCataLog) ListPChannel(ctx context.Context) ([]*streamingpb.PChannelMeta, error) { + ret := _m.Called(ctx) + + var r0 []*streamingpb.PChannelMeta + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]*streamingpb.PChannelMeta, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []*streamingpb.PChannelMeta); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*streamingpb.PChannelMeta) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockStreamingCoordCataLog_ListPChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListPChannel' +type MockStreamingCoordCataLog_ListPChannel_Call struct { + *mock.Call +} + +// ListPChannel is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockStreamingCoordCataLog_Expecter) ListPChannel(ctx interface{}) *MockStreamingCoordCataLog_ListPChannel_Call { + return &MockStreamingCoordCataLog_ListPChannel_Call{Call: _e.mock.On("ListPChannel", ctx)} +} + +func (_c *MockStreamingCoordCataLog_ListPChannel_Call) Run(run func(ctx context.Context)) *MockStreamingCoordCataLog_ListPChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockStreamingCoordCataLog_ListPChannel_Call) Return(_a0 []*streamingpb.PChannelMeta, _a1 error) *MockStreamingCoordCataLog_ListPChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockStreamingCoordCataLog_ListPChannel_Call) RunAndReturn(run func(context.Context) ([]*streamingpb.PChannelMeta, error)) *MockStreamingCoordCataLog_ListPChannel_Call { + _c.Call.Return(run) + return _c +} + +// SavePChannels provides a mock function with given fields: ctx, info +func (_m *MockStreamingCoordCataLog) SavePChannels(ctx context.Context, info []*streamingpb.PChannelMeta) error { + ret := _m.Called(ctx, info) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []*streamingpb.PChannelMeta) error); ok { + r0 = rf(ctx, info) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingCoordCataLog_SavePChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SavePChannels' +type MockStreamingCoordCataLog_SavePChannels_Call struct { + *mock.Call +} + +// SavePChannels is a helper method to define mock.On call +// - ctx context.Context +// - info []*streamingpb.PChannelMeta +func (_e *MockStreamingCoordCataLog_Expecter) SavePChannels(ctx interface{}, info interface{}) *MockStreamingCoordCataLog_SavePChannels_Call { + return &MockStreamingCoordCataLog_SavePChannels_Call{Call: _e.mock.On("SavePChannels", ctx, info)} +} + +func (_c *MockStreamingCoordCataLog_SavePChannels_Call) Run(run func(ctx context.Context, info []*streamingpb.PChannelMeta)) *MockStreamingCoordCataLog_SavePChannels_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]*streamingpb.PChannelMeta)) + }) + return _c +} + +func (_c *MockStreamingCoordCataLog_SavePChannels_Call) Return(_a0 error) *MockStreamingCoordCataLog_SavePChannels_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingCoordCataLog_SavePChannels_Call) RunAndReturn(run func(context.Context, []*streamingpb.PChannelMeta) error) *MockStreamingCoordCataLog_SavePChannels_Call { + _c.Call.Return(run) + return _c +} + +// NewMockStreamingCoordCataLog creates a new instance of MockStreamingCoordCataLog. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockStreamingCoordCataLog(t interface { + mock.TestingT + Cleanup(func()) +}) *MockStreamingCoordCataLog { + mock := &MockStreamingCoordCataLog{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/proto/mock_streamingpb/mock_StreamingCoordAssignmentService_AssignmentDiscoverServer.go b/internal/mocks/proto/mock_streamingpb/mock_StreamingCoordAssignmentService_AssignmentDiscoverServer.go new file mode 100644 index 0000000000000..efcfea048991e --- /dev/null +++ b/internal/mocks/proto/mock_streamingpb/mock_StreamingCoordAssignmentService_AssignmentDiscoverServer.go @@ -0,0 +1,378 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mock_streamingpb + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + metadata "google.golang.org/grpc/metadata" + + streamingpb "github.com/milvus-io/milvus/internal/proto/streamingpb" +) + +// MockStreamingCoordAssignmentService_AssignmentDiscoverServer is an autogenerated mock type for the StreamingCoordAssignmentService_AssignmentDiscoverServer type +type MockStreamingCoordAssignmentService_AssignmentDiscoverServer struct { + mock.Mock +} + +type MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Expecter struct { + mock *mock.Mock +} + +func (_m *MockStreamingCoordAssignmentService_AssignmentDiscoverServer) EXPECT() *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Expecter { + return &MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Expecter{mock: &_m.Mock} +} + +// Context provides a mock function with given fields: +func (_m *MockStreamingCoordAssignmentService_AssignmentDiscoverServer) Context() context.Context { + ret := _m.Called() + + var r0 context.Context + if rf, ok := ret.Get(0).(func() context.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(context.Context) + } + } + + return r0 +} + +// MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Context_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Context' +type MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Context_Call struct { + *mock.Call +} + +// Context is a helper method to define mock.On call +func (_e *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Expecter) Context() *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Context_Call { + return &MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Context_Call{Call: _e.mock.On("Context")} +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Context_Call) Run(run func()) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Context_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Context_Call) Return(_a0 context.Context) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Context_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Context_Call) RunAndReturn(run func() context.Context) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Context_Call { + _c.Call.Return(run) + return _c +} + +// Recv provides a mock function with given fields: +func (_m *MockStreamingCoordAssignmentService_AssignmentDiscoverServer) Recv() (*streamingpb.AssignmentDiscoverRequest, error) { + ret := _m.Called() + + var r0 *streamingpb.AssignmentDiscoverRequest + var r1 error + if rf, ok := ret.Get(0).(func() (*streamingpb.AssignmentDiscoverRequest, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() *streamingpb.AssignmentDiscoverRequest); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*streamingpb.AssignmentDiscoverRequest) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Recv_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Recv' +type MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Recv_Call struct { + *mock.Call +} + +// Recv is a helper method to define mock.On call +func (_e *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Expecter) Recv() *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Recv_Call { + return &MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Recv_Call{Call: _e.mock.On("Recv")} +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Recv_Call) Run(run func()) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Recv_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Recv_Call) Return(_a0 *streamingpb.AssignmentDiscoverRequest, _a1 error) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Recv_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Recv_Call) RunAndReturn(run func() (*streamingpb.AssignmentDiscoverRequest, error)) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Recv_Call { + _c.Call.Return(run) + return _c +} + +// RecvMsg provides a mock function with given fields: m +func (_m *MockStreamingCoordAssignmentService_AssignmentDiscoverServer) RecvMsg(m interface{}) error { + ret := _m.Called(m) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingCoordAssignmentService_AssignmentDiscoverServer_RecvMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecvMsg' +type MockStreamingCoordAssignmentService_AssignmentDiscoverServer_RecvMsg_Call struct { + *mock.Call +} + +// RecvMsg is a helper method to define mock.On call +// - m interface{} +func (_e *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Expecter) RecvMsg(m interface{}) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_RecvMsg_Call { + return &MockStreamingCoordAssignmentService_AssignmentDiscoverServer_RecvMsg_Call{Call: _e.mock.On("RecvMsg", m)} +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_RecvMsg_Call) Run(run func(m interface{})) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_RecvMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_RecvMsg_Call) Return(_a0 error) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_RecvMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_RecvMsg_Call) RunAndReturn(run func(interface{}) error) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_RecvMsg_Call { + _c.Call.Return(run) + return _c +} + +// Send provides a mock function with given fields: _a0 +func (_m *MockStreamingCoordAssignmentService_AssignmentDiscoverServer) Send(_a0 *streamingpb.AssignmentDiscoverResponse) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(*streamingpb.AssignmentDiscoverResponse) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Send_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Send' +type MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Send_Call struct { + *mock.Call +} + +// Send is a helper method to define mock.On call +// - _a0 *streamingpb.AssignmentDiscoverResponse +func (_e *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Expecter) Send(_a0 interface{}) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Send_Call { + return &MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Send_Call{Call: _e.mock.On("Send", _a0)} +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Send_Call) Run(run func(_a0 *streamingpb.AssignmentDiscoverResponse)) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Send_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*streamingpb.AssignmentDiscoverResponse)) + }) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Send_Call) Return(_a0 error) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Send_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Send_Call) RunAndReturn(run func(*streamingpb.AssignmentDiscoverResponse) error) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Send_Call { + _c.Call.Return(run) + return _c +} + +// SendHeader provides a mock function with given fields: _a0 +func (_m *MockStreamingCoordAssignmentService_AssignmentDiscoverServer) SendHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendHeader' +type MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendHeader_Call struct { + *mock.Call +} + +// SendHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Expecter) SendHeader(_a0 interface{}) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendHeader_Call { + return &MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendHeader_Call{Call: _e.mock.On("SendHeader", _a0)} +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendHeader_Call) Run(run func(_a0 metadata.MD)) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendHeader_Call) Return(_a0 error) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendHeader_Call) RunAndReturn(run func(metadata.MD) error) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendHeader_Call { + _c.Call.Return(run) + return _c +} + +// SendMsg provides a mock function with given fields: m +func (_m *MockStreamingCoordAssignmentService_AssignmentDiscoverServer) SendMsg(m interface{}) error { + ret := _m.Called(m) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendMsg' +type MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendMsg_Call struct { + *mock.Call +} + +// SendMsg is a helper method to define mock.On call +// - m interface{} +func (_e *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Expecter) SendMsg(m interface{}) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendMsg_Call { + return &MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendMsg_Call{Call: _e.mock.On("SendMsg", m)} +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendMsg_Call) Run(run func(m interface{})) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendMsg_Call) Return(_a0 error) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendMsg_Call) RunAndReturn(run func(interface{}) error) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SendMsg_Call { + _c.Call.Return(run) + return _c +} + +// SetHeader provides a mock function with given fields: _a0 +func (_m *MockStreamingCoordAssignmentService_AssignmentDiscoverServer) SetHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetHeader' +type MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetHeader_Call struct { + *mock.Call +} + +// SetHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Expecter) SetHeader(_a0 interface{}) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetHeader_Call { + return &MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetHeader_Call{Call: _e.mock.On("SetHeader", _a0)} +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetHeader_Call) Run(run func(_a0 metadata.MD)) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetHeader_Call) Return(_a0 error) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetHeader_Call) RunAndReturn(run func(metadata.MD) error) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetHeader_Call { + _c.Call.Return(run) + return _c +} + +// SetTrailer provides a mock function with given fields: _a0 +func (_m *MockStreamingCoordAssignmentService_AssignmentDiscoverServer) SetTrailer(_a0 metadata.MD) { + _m.Called(_a0) +} + +// MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetTrailer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetTrailer' +type MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetTrailer_Call struct { + *mock.Call +} + +// SetTrailer is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_Expecter) SetTrailer(_a0 interface{}) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetTrailer_Call { + return &MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetTrailer_Call{Call: _e.mock.On("SetTrailer", _a0)} +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetTrailer_Call) Run(run func(_a0 metadata.MD)) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetTrailer_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetTrailer_Call) Return() *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetTrailer_Call { + _c.Call.Return() + return _c +} + +func (_c *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetTrailer_Call) RunAndReturn(run func(metadata.MD)) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer_SetTrailer_Call { + _c.Call.Return(run) + return _c +} + +// NewMockStreamingCoordAssignmentService_AssignmentDiscoverServer creates a new instance of MockStreamingCoordAssignmentService_AssignmentDiscoverServer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockStreamingCoordAssignmentService_AssignmentDiscoverServer(t interface { + mock.TestingT + Cleanup(func()) +}) *MockStreamingCoordAssignmentService_AssignmentDiscoverServer { + mock := &MockStreamingCoordAssignmentService_AssignmentDiscoverServer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go b/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go new file mode 100644 index 0000000000000..f764688f9d08c --- /dev/null +++ b/internal/mocks/streamingcoord/server/mock_balancer/mock_Balancer.go @@ -0,0 +1,199 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mock_balancer + +import ( + context "context" + + types "github.com/milvus-io/milvus/pkg/streaming/util/types" + mock "github.com/stretchr/testify/mock" + + typeutil "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// MockBalancer is an autogenerated mock type for the Balancer type +type MockBalancer struct { + mock.Mock +} + +type MockBalancer_Expecter struct { + mock *mock.Mock +} + +func (_m *MockBalancer) EXPECT() *MockBalancer_Expecter { + return &MockBalancer_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with given fields: +func (_m *MockBalancer) Close() { + _m.Called() +} + +// MockBalancer_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockBalancer_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockBalancer_Expecter) Close() *MockBalancer_Close_Call { + return &MockBalancer_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockBalancer_Close_Call) Run(run func()) *MockBalancer_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockBalancer_Close_Call) Return() *MockBalancer_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *MockBalancer_Close_Call) RunAndReturn(run func()) *MockBalancer_Close_Call { + _c.Call.Return(run) + return _c +} + +// MarkAsUnavailable provides a mock function with given fields: ctx, pChannels +func (_m *MockBalancer) MarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) error { + ret := _m.Called(ctx, pChannels) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []types.PChannelInfo) error); ok { + r0 = rf(ctx, pChannels) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBalancer_MarkAsUnavailable_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MarkAsUnavailable' +type MockBalancer_MarkAsUnavailable_Call struct { + *mock.Call +} + +// MarkAsUnavailable is a helper method to define mock.On call +// - ctx context.Context +// - pChannels []types.PChannelInfo +func (_e *MockBalancer_Expecter) MarkAsUnavailable(ctx interface{}, pChannels interface{}) *MockBalancer_MarkAsUnavailable_Call { + return &MockBalancer_MarkAsUnavailable_Call{Call: _e.mock.On("MarkAsUnavailable", ctx, pChannels)} +} + +func (_c *MockBalancer_MarkAsUnavailable_Call) Run(run func(ctx context.Context, pChannels []types.PChannelInfo)) *MockBalancer_MarkAsUnavailable_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]types.PChannelInfo)) + }) + return _c +} + +func (_c *MockBalancer_MarkAsUnavailable_Call) Return(_a0 error) *MockBalancer_MarkAsUnavailable_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBalancer_MarkAsUnavailable_Call) RunAndReturn(run func(context.Context, []types.PChannelInfo) error) *MockBalancer_MarkAsUnavailable_Call { + _c.Call.Return(run) + return _c +} + +// Trigger provides a mock function with given fields: ctx +func (_m *MockBalancer) Trigger(ctx context.Context) error { + ret := _m.Called(ctx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBalancer_Trigger_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Trigger' +type MockBalancer_Trigger_Call struct { + *mock.Call +} + +// Trigger is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockBalancer_Expecter) Trigger(ctx interface{}) *MockBalancer_Trigger_Call { + return &MockBalancer_Trigger_Call{Call: _e.mock.On("Trigger", ctx)} +} + +func (_c *MockBalancer_Trigger_Call) Run(run func(ctx context.Context)) *MockBalancer_Trigger_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockBalancer_Trigger_Call) Return(_a0 error) *MockBalancer_Trigger_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBalancer_Trigger_Call) RunAndReturn(run func(context.Context) error) *MockBalancer_Trigger_Call { + _c.Call.Return(run) + return _c +} + +// WatchBalanceResult provides a mock function with given fields: ctx, cb +func (_m *MockBalancer) WatchBalanceResult(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error { + ret := _m.Called(ctx, cb) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error); ok { + r0 = rf(ctx, cb) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBalancer_WatchBalanceResult_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchBalanceResult' +type MockBalancer_WatchBalanceResult_Call struct { + *mock.Call +} + +// WatchBalanceResult is a helper method to define mock.On call +// - ctx context.Context +// - cb func(typeutil.VersionInt64Pair , []types.PChannelInfoAssigned) error +func (_e *MockBalancer_Expecter) WatchBalanceResult(ctx interface{}, cb interface{}) *MockBalancer_WatchBalanceResult_Call { + return &MockBalancer_WatchBalanceResult_Call{Call: _e.mock.On("WatchBalanceResult", ctx, cb)} +} + +func (_c *MockBalancer_WatchBalanceResult_Call) Run(run func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error)) *MockBalancer_WatchBalanceResult_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error)) + }) + return _c +} + +func (_c *MockBalancer_WatchBalanceResult_Call) Return(_a0 error) *MockBalancer_WatchBalanceResult_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBalancer_WatchBalanceResult_Call) RunAndReturn(run func(context.Context, func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error) *MockBalancer_WatchBalanceResult_Call { + _c.Call.Return(run) + return _c +} + +// NewMockBalancer creates a new instance of MockBalancer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockBalancer(t interface { + mock.TestingT + Cleanup(func()) +}) *MockBalancer { + mock := &MockBalancer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/streamingnode/client/mock_manager/mock_ManagerClient.go b/internal/mocks/streamingnode/client/mock_manager/mock_ManagerClient.go new file mode 100644 index 0000000000000..e5e69d7721081 --- /dev/null +++ b/internal/mocks/streamingnode/client/mock_manager/mock_ManagerClient.go @@ -0,0 +1,256 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mock_manager + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + sessionutil "github.com/milvus-io/milvus/internal/util/sessionutil" + + types "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +// MockManagerClient is an autogenerated mock type for the ManagerClient type +type MockManagerClient struct { + mock.Mock +} + +type MockManagerClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockManagerClient) EXPECT() *MockManagerClient_Expecter { + return &MockManagerClient_Expecter{mock: &_m.Mock} +} + +// Assign provides a mock function with given fields: ctx, pchannel +func (_m *MockManagerClient) Assign(ctx context.Context, pchannel types.PChannelInfoAssigned) error { + ret := _m.Called(ctx, pchannel) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.PChannelInfoAssigned) error); ok { + r0 = rf(ctx, pchannel) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockManagerClient_Assign_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Assign' +type MockManagerClient_Assign_Call struct { + *mock.Call +} + +// Assign is a helper method to define mock.On call +// - ctx context.Context +// - pchannel types.PChannelInfoAssigned +func (_e *MockManagerClient_Expecter) Assign(ctx interface{}, pchannel interface{}) *MockManagerClient_Assign_Call { + return &MockManagerClient_Assign_Call{Call: _e.mock.On("Assign", ctx, pchannel)} +} + +func (_c *MockManagerClient_Assign_Call) Run(run func(ctx context.Context, pchannel types.PChannelInfoAssigned)) *MockManagerClient_Assign_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.PChannelInfoAssigned)) + }) + return _c +} + +func (_c *MockManagerClient_Assign_Call) Return(_a0 error) *MockManagerClient_Assign_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockManagerClient_Assign_Call) RunAndReturn(run func(context.Context, types.PChannelInfoAssigned) error) *MockManagerClient_Assign_Call { + _c.Call.Return(run) + return _c +} + +// Close provides a mock function with given fields: +func (_m *MockManagerClient) Close() { + _m.Called() +} + +// MockManagerClient_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockManagerClient_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockManagerClient_Expecter) Close() *MockManagerClient_Close_Call { + return &MockManagerClient_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockManagerClient_Close_Call) Run(run func()) *MockManagerClient_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockManagerClient_Close_Call) Return() *MockManagerClient_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *MockManagerClient_Close_Call) RunAndReturn(run func()) *MockManagerClient_Close_Call { + _c.Call.Return(run) + return _c +} + +// CollectAllStatus provides a mock function with given fields: ctx +func (_m *MockManagerClient) CollectAllStatus(ctx context.Context) (map[int64]types.StreamingNodeStatus, error) { + ret := _m.Called(ctx) + + var r0 map[int64]types.StreamingNodeStatus + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (map[int64]types.StreamingNodeStatus, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) map[int64]types.StreamingNodeStatus); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]types.StreamingNodeStatus) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockManagerClient_CollectAllStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CollectAllStatus' +type MockManagerClient_CollectAllStatus_Call struct { + *mock.Call +} + +// CollectAllStatus is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockManagerClient_Expecter) CollectAllStatus(ctx interface{}) *MockManagerClient_CollectAllStatus_Call { + return &MockManagerClient_CollectAllStatus_Call{Call: _e.mock.On("CollectAllStatus", ctx)} +} + +func (_c *MockManagerClient_CollectAllStatus_Call) Run(run func(ctx context.Context)) *MockManagerClient_CollectAllStatus_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockManagerClient_CollectAllStatus_Call) Return(_a0 map[int64]types.StreamingNodeStatus, _a1 error) *MockManagerClient_CollectAllStatus_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockManagerClient_CollectAllStatus_Call) RunAndReturn(run func(context.Context) (map[int64]types.StreamingNodeStatus, error)) *MockManagerClient_CollectAllStatus_Call { + _c.Call.Return(run) + return _c +} + +// Remove provides a mock function with given fields: ctx, pchannel +func (_m *MockManagerClient) Remove(ctx context.Context, pchannel types.PChannelInfoAssigned) error { + ret := _m.Called(ctx, pchannel) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.PChannelInfoAssigned) error); ok { + r0 = rf(ctx, pchannel) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockManagerClient_Remove_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Remove' +type MockManagerClient_Remove_Call struct { + *mock.Call +} + +// Remove is a helper method to define mock.On call +// - ctx context.Context +// - pchannel types.PChannelInfoAssigned +func (_e *MockManagerClient_Expecter) Remove(ctx interface{}, pchannel interface{}) *MockManagerClient_Remove_Call { + return &MockManagerClient_Remove_Call{Call: _e.mock.On("Remove", ctx, pchannel)} +} + +func (_c *MockManagerClient_Remove_Call) Run(run func(ctx context.Context, pchannel types.PChannelInfoAssigned)) *MockManagerClient_Remove_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.PChannelInfoAssigned)) + }) + return _c +} + +func (_c *MockManagerClient_Remove_Call) Return(_a0 error) *MockManagerClient_Remove_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockManagerClient_Remove_Call) RunAndReturn(run func(context.Context, types.PChannelInfoAssigned) error) *MockManagerClient_Remove_Call { + _c.Call.Return(run) + return _c +} + +// WatchNodeChanged provides a mock function with given fields: ctx +func (_m *MockManagerClient) WatchNodeChanged(ctx context.Context) <-chan map[int64]*sessionutil.SessionRaw { + ret := _m.Called(ctx) + + var r0 <-chan map[int64]*sessionutil.SessionRaw + if rf, ok := ret.Get(0).(func(context.Context) <-chan map[int64]*sessionutil.SessionRaw); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan map[int64]*sessionutil.SessionRaw) + } + } + + return r0 +} + +// MockManagerClient_WatchNodeChanged_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchNodeChanged' +type MockManagerClient_WatchNodeChanged_Call struct { + *mock.Call +} + +// WatchNodeChanged is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockManagerClient_Expecter) WatchNodeChanged(ctx interface{}) *MockManagerClient_WatchNodeChanged_Call { + return &MockManagerClient_WatchNodeChanged_Call{Call: _e.mock.On("WatchNodeChanged", ctx)} +} + +func (_c *MockManagerClient_WatchNodeChanged_Call) Run(run func(ctx context.Context)) *MockManagerClient_WatchNodeChanged_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockManagerClient_WatchNodeChanged_Call) Return(_a0 <-chan map[int64]*sessionutil.SessionRaw) *MockManagerClient_WatchNodeChanged_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockManagerClient_WatchNodeChanged_Call) RunAndReturn(run func(context.Context) <-chan map[int64]*sessionutil.SessionRaw) *MockManagerClient_WatchNodeChanged_Call { + _c.Call.Return(run) + return _c +} + +// NewMockManagerClient creates a new instance of MockManagerClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockManagerClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockManagerClient { + mock := &MockManagerClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/proto/streaming.proto b/internal/proto/streaming.proto index c6d2d7c4fe11d..2ed98d7d3a8fe 100644 --- a/internal/proto/streaming.proto +++ b/internal/proto/streaming.proto @@ -1,6 +1,6 @@ syntax = "proto3"; -package milvus.proto.log; +package milvus.proto.streaming; option go_package = "github.com/milvus-io/milvus/internal/proto/streamingpb"; @@ -18,23 +18,144 @@ message MessageID { // Message is the basic unit of communication between publisher and consumer. message Message { - bytes payload = 1; // message body - map properties = 2; // message properties + bytes payload = 1; // message body + map properties = 2; // message properties } -// PChannelInfo is the information of a pchannel info. +// PChannelInfo is the information of a pchannel info, should only keep the basic info of a pchannel. +// It's used in many rpc and meta, so keep it simple. message PChannelInfo { - string name = 1; // channel name - int64 term = 2; // A monotonic increasing term, every time the channel is recovered or moved to another streamingnode, the term will increase by meta server. + string name = 1; // channel name + int64 term = + 2; // A monotonic increasing term, every time the channel is recovered or moved to another streamingnode, the term will increase by meta server. +} + +// PChannelMetaHistory is the history meta information of a pchannel, should only keep the data that is necessary to persistent. +message PChannelMetaHistory { + int64 term = 1; // term when server assigned. + StreamingNodeInfo node = + 2; // streaming node that the channel is assigned to. +} + +// PChannelMetaState +enum PChannelMetaState { + PCHANNEL_META_STATE_UNKNOWN = 0; // should never used. + PCHANNEL_META_STATE_UNINITIALIZED = + 1; // channel is uninitialized, never assgined to any streaming node. + PCHANNEL_META_STATE_ASSIGNING = + 2; // new term is allocated, but not determined to be assgined. + PCHANNEL_META_STATE_ASSIGNED = + 3; // channel is assigned to a streaming node. + PCHANNEL_META_STATE_UNAVAILABLE = + 4; // channel is unavailable at this term. +} + +// PChannelMeta is the meta information of a pchannel, should only keep the data that is necessary to persistent. +// It's only used in meta, so do not use it in rpc. +message PChannelMeta { + PChannelInfo channel = 1; // keep the meta info that current assigned to. + StreamingNodeInfo node = 2; // nil if channel is not uninitialized. + PChannelMetaState state = 3; // state of the channel. + repeated PChannelMetaHistory histories = + 4; // keep the meta info history that used to be assigned to. +} + +// VersionPair is the version pair of global and local. +message VersionPair { + int64 global = 1; + int64 local = 2; +} + +// +// Milvus Service +// + +service StreamingCoordStateService { + rpc GetComponentStates(milvus.GetComponentStatesRequest) + returns (milvus.ComponentStates) { + } +} + +service StreamingNodeStateService { + rpc GetComponentStates(milvus.GetComponentStatesRequest) + returns (milvus.ComponentStates) { + } +} + +// +// StreamingCoordAssignmentService +// + +// StreamingCoordAssignmentService is the global log management service. +// Server: log coord. Running on every log node. +// Client: all log publish/consuming node. +service StreamingCoordAssignmentService { + // AssignmentDiscover is used to discover all log nodes managed by the streamingcoord. + // Channel assignment information will be pushed to client by stream. + rpc AssignmentDiscover(stream AssignmentDiscoverRequest) + returns (stream AssignmentDiscoverResponse) { + } +} + +// AssignmentDiscoverRequest is the request of Discovery +message AssignmentDiscoverRequest { + oneof command { + ReportAssignmentErrorRequest report_error = + 1; // report streaming error, trigger reassign right now. + CloseAssignmentDiscoverRequest close = 2; // close the stream. + } +} + +// ReportAssignmentErrorRequest is the request to report assignment error happens. +message ReportAssignmentErrorRequest { + PChannelInfo pchannel = 1; // channel + StreamingError err = 2; // error happend on log node +} + +// CloseAssignmentDiscoverRequest is the request to close the stream. +message CloseAssignmentDiscoverRequest { +} + +// AssignmentDiscoverResponse is the response of Discovery +message AssignmentDiscoverResponse { + oneof response { + FullStreamingNodeAssignmentWithVersion full_assignment = + 1; // all assignment info. + // TODO: may be support partial assignment info in future. + CloseAssignmentDiscoverResponse close = 2; + } +} + +// FullStreamingNodeAssignmentWithVersion is the full assignment info of a log node with version. +message FullStreamingNodeAssignmentWithVersion { + VersionPair version = 1; + repeated StreamingNodeAssignment assignments = 2; +} + +message CloseAssignmentDiscoverResponse { +} + +// StreamingNodeInfo is the information of a streaming node. +message StreamingNodeInfo { + int64 server_id = 1; + string address = 2; +} + +// StreamingNodeAssignment is the assignment info of a streaming node. +message StreamingNodeAssignment { + StreamingNodeInfo node = 1; + repeated PChannelInfo channels = 2; } // DeliverPolicy is the policy to deliver message. message DeliverPolicy { oneof policy { - google.protobuf.Empty all = 1; // deliver all messages. - google.protobuf.Empty latest = 2; // deliver the latest message. - MessageID start_from = 3; // deliver message from this message id. [startFrom, ...] - MessageID start_after = 4; // deliver message after this message id. (startAfter, ...] + google.protobuf.Empty all = 1; // deliver all messages. + google.protobuf.Empty latest = 2; // deliver the latest message. + MessageID start_from = + 3; // deliver message from this message id. [startFrom, ...] + MessageID start_after = + 4; // deliver message after this message id. (startAfter, ...] } } @@ -49,33 +170,35 @@ message DeliverFilter { // DeliverFilterTimeTickGT is the filter to deliver message with time tick greater than this value. message DeliverFilterTimeTickGT { - uint64 time_tick = 1; // deliver message with time tick greater than this value. + uint64 time_tick = + 1; // deliver message with time tick greater than this value. } // DeliverFilterTimeTickGTE is the filter to deliver message with time tick greater than or equal to this value. message DeliverFilterTimeTickGTE { - uint64 time_tick = 1; // deliver message with time tick greater than or equal to this value. + uint64 time_tick = + 1; // deliver message with time tick greater than or equal to this value. } // DeliverFilterVChannel is the filter to deliver message with vchannel name. message DeliverFilterVChannel { - string vchannel = 1; // deliver message with vchannel name. + string vchannel = 1; // deliver message with vchannel name. } // StreamingCode is the error code for log internal component. enum StreamingCode { STREAMING_CODE_OK = 0; - STREAMING_CODE_CHANNEL_EXIST = 1; // channel already exist - STREAMING_CODE_CHANNEL_NOT_EXIST = 2; // channel not exist - STREAMING_CODE_CHANNEL_FENCED = 3; // channel is fenced - STREAMING_CODE_ON_SHUTDOWN = 4; // component is on shutdown - STREAMING_CODE_INVALID_REQUEST_SEQ = 5; // invalid request sequence - STREAMING_CODE_UNMATCHED_CHANNEL_TERM = 6; // unmatched channel term - STREAMING_CODE_IGNORED_OPERATION = 7; // ignored operation - STREAMING_CODE_INNER = 8; // underlying service failure. - STREAMING_CODE_EOF = 9; // end of stream, generated by grpc status. - STREAMING_CODE_INVAILD_ARGUMENT = 10; // invalid argument - STREAMING_CODE_UNKNOWN = 999; // unknown error + STREAMING_CODE_CHANNEL_EXIST = 1; // channel already exist + STREAMING_CODE_CHANNEL_NOT_EXIST = 2; // channel not exist + STREAMING_CODE_CHANNEL_FENCED = 3; // channel is fenced + STREAMING_CODE_ON_SHUTDOWN = 4; // component is on shutdown + STREAMING_CODE_INVALID_REQUEST_SEQ = 5; // invalid request sequence + STREAMING_CODE_UNMATCHED_CHANNEL_TERM = 6; // unmatched channel term + STREAMING_CODE_IGNORED_OPERATION = 7; // ignored operation + STREAMING_CODE_INNER = 8; // underlying service failure. + STREAMING_CODE_EOF = 9; // end of stream, generated by grpc status. + STREAMING_CODE_INVAILD_ARGUMENT = 10; // invalid argument + STREAMING_CODE_UNKNOWN = 999; // unknown error } // StreamingError is the error type for log internal component. @@ -84,7 +207,6 @@ message StreamingError { string cause = 2; } - // // StreamingNodeHandlerService // @@ -101,7 +223,8 @@ service StreamingNodeHandlerService { // Error: // If channel isn't assign to this log node, the RPC will return error CHANNEL_NOT_EXIST. // If channel is moving away to other log node, the RPC will return error CHANNEL_FENCED. - rpc Produce(stream ProduceRequest) returns (stream ProduceResponse) {}; + rpc Produce(stream ProduceRequest) returns (stream ProduceResponse) { + }; // Consume is a server streaming RPC to receive messages from a channel. // All message after given startMessageID and excluding will be sent to the client by stream. @@ -109,7 +232,8 @@ service StreamingNodeHandlerService { // Error: // If channel isn't assign to this log node, the RPC will return error CHANNEL_NOT_EXIST. // If channel is moving away to other log node, the RPC will return error CHANNEL_FENCED. - rpc Consume(stream ConsumeRequest) returns (stream ConsumeResponse) {}; + rpc Consume(stream ConsumeRequest) returns (stream ConsumeResponse) { + }; } // ProduceRequest is the request of the Produce RPC. @@ -129,8 +253,8 @@ message CreateProducerRequest { // ProduceMessageRequest is the request of the Produce RPC. message ProduceMessageRequest { - int64 request_id = 1; // request id for reply. - Message message = 2; // message to be sent. + int64 request_id = 1; // request id for reply. + Message message = 2; // message to be sent. } // CloseProducerRequest is the request of the CloseProducer RPC. @@ -149,8 +273,9 @@ message ProduceResponse { // CreateProducerResponse is the result of the CreateProducer RPC. message CreateProducerResponse { - int64 producer_id = 1; // A unique producer id on streamingnode for this producer in streamingnode lifetime. - // Is used to identify the producer in streamingnode for other unary grpc call at producer level. + int64 producer_id = + 1; // A unique producer id on streamingnode for this producer in streamingnode lifetime. + // Is used to identify the producer in streamingnode for other unary grpc call at producer level. } message ProduceMessageResponse { @@ -163,7 +288,7 @@ message ProduceMessageResponse { // ProduceMessageResponseResult is the result of the produce message streaming RPC. message ProduceMessageResponseResult { - MessageID id = 1; // the offset of the message in the channel + MessageID id = 1; // the offset of the message in the channel } // CloseProducerResponse is the result of the CloseProducer RPC. @@ -187,8 +312,8 @@ message CloseConsumerRequest { // CreateConsumerRequest is passed in the header of stream. message CreateConsumerRequest { PChannelInfo pchannel = 1; - DeliverPolicy deliver_policy = 2; // deliver policy. - repeated DeliverFilter deliver_filters = 3; // deliver filter. + DeliverPolicy deliver_policy = 2; // deliver policy. + repeated DeliverFilter deliver_filters = 3; // deliver filter. } // ConsumeResponse is the reponse of the Consume RPC. @@ -204,8 +329,8 @@ message CreateConsumerResponse { } message ConsumeMessageReponse { - MessageID id = 1; // message id of message. - Message message = 2; // message to be consumed. + MessageID id = 1; // message id of message. + Message message = 2; // message to be consumed. } message CloseConsumerResponse { @@ -223,7 +348,9 @@ service StreamingNodeManagerService { // Block until the channel assignd is ready to read or write on the log node. // Error: // If the channel already exists, return error with code CHANNEL_EXIST. - rpc Assign(StreamingNodeManagerAssignRequest) returns (StreamingNodeManagerAssignResponse) {}; + rpc Assign(StreamingNodeManagerAssignRequest) + returns (StreamingNodeManagerAssignResponse) { + }; // Remove is unary RPC to remove a channel on a log node. // Data of the channel on flying would be sent or flused as much as possible. @@ -231,12 +358,16 @@ service StreamingNodeManagerService { // New incoming request of handler of this channel will be rejected with special error. // Error: // If the channel does not exist, return error with code CHANNEL_NOT_EXIST. - rpc Remove(StreamingNodeManagerRemoveRequest) returns (StreamingNodeManagerRemoveResponse) {}; + rpc Remove(StreamingNodeManagerRemoveRequest) + returns (StreamingNodeManagerRemoveResponse) { + }; // rpc CollectStatus() ... // CollectStatus is unary RPC to collect all avaliable channel info and load balance info on a log node. // Used to recover channel info on log coord, collect balance info and health check. - rpc CollectStatus(StreamingNodeManagerCollectStatusRequest) returns (StreamingNodeManagerCollectStatusResponse) {}; + rpc CollectStatus(StreamingNodeManagerCollectStatusRequest) + returns (StreamingNodeManagerCollectStatusResponse) { + }; } // StreamingManagerAssignRequest is the request message of Assign RPC. @@ -251,7 +382,8 @@ message StreamingNodeManagerRemoveRequest { PChannelInfo pchannel = 1; } -message StreamingNodeManagerRemoveResponse {} +message StreamingNodeManagerRemoveResponse { +} message StreamingNodeManagerCollectStatusRequest { } diff --git a/internal/streamingcoord/server/balancer/balance_timer.go b/internal/streamingcoord/server/balancer/balance_timer.go new file mode 100644 index 0000000000000..ff6ee4ba24da0 --- /dev/null +++ b/internal/streamingcoord/server/balancer/balance_timer.go @@ -0,0 +1,53 @@ +package balancer + +import ( + "time" + + "github.com/cenkalti/backoff/v4" + + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +// newBalanceTimer creates a new balanceTimer +func newBalanceTimer() *balanceTimer { + return &balanceTimer{ + backoff: backoff.NewExponentialBackOff(), + newIncomingBackOff: false, + } +} + +// balanceTimer is a timer for balance operation +type balanceTimer struct { + backoff *backoff.ExponentialBackOff + newIncomingBackOff bool + enableBackoff bool +} + +// EnableBackoffOrNot enables or disables backoff +func (t *balanceTimer) EnableBackoff() { + t.enableBackoff = true + t.newIncomingBackOff = true +} + +// DisableBackoff disables backoff +func (t *balanceTimer) DisableBackoff() { + t.enableBackoff = false +} + +// NextTimer returns the next timer and the duration of the timer +func (t *balanceTimer) NextTimer() (<-chan time.Time, time.Duration) { + if !t.enableBackoff { + balanceInterval := paramtable.Get().StreamingCoordCfg.AutoBalanceTriggerInterval.GetAsDurationByParse() + return time.After(balanceInterval), balanceInterval + } + if t.newIncomingBackOff { + t.newIncomingBackOff = false + // reconfig backoff + t.backoff.InitialInterval = paramtable.Get().StreamingCoordCfg.AutoBalanceBackoffInitialInterval.GetAsDurationByParse() + t.backoff.Multiplier = paramtable.Get().StreamingCoordCfg.AutoBalanceBackoffMultiplier.GetAsFloat() + t.backoff.MaxInterval = paramtable.Get().StreamingCoordCfg.AutoBalanceTriggerInterval.GetAsDurationByParse() + t.backoff.Reset() + } + nextBackoff := t.backoff.NextBackOff() + return time.After(nextBackoff), nextBackoff +} diff --git a/internal/streamingcoord/server/balancer/balancer.go b/internal/streamingcoord/server/balancer/balancer.go new file mode 100644 index 0000000000000..cd78f430e7e86 --- /dev/null +++ b/internal/streamingcoord/server/balancer/balancer.go @@ -0,0 +1,28 @@ +package balancer + +import ( + "context" + + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +var _ Balancer = (*balancerImpl)(nil) + +// Balancer is a load balancer to balance the load of log node. +// Given the balance result to assign or remove channels to corresponding log node. +// Balancer is a local component, it should promise all channel can be assigned, and reach the final consistency. +// Balancer should be thread safe. +type Balancer interface { + // WatchBalanceResult watches the balance result. + WatchBalanceResult(ctx context.Context, cb func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error) error + + // MarkAsAvailable marks the pchannels as available, and trigger a rebalance. + MarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) error + + // Trigger is a hint to trigger a balance. + Trigger(ctx context.Context) error + + // Close close the balancer. + Close() +} diff --git a/internal/streamingcoord/server/balancer/balancer_impl.go b/internal/streamingcoord/server/balancer/balancer_impl.go new file mode 100644 index 0000000000000..d56dc236b41db --- /dev/null +++ b/internal/streamingcoord/server/balancer/balancer_impl.go @@ -0,0 +1,277 @@ +package balancer + +import ( + "context" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel" + "github.com/milvus-io/milvus/internal/streamingnode/client/manager" + "github.com/milvus-io/milvus/internal/util/streamingutil/status" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/lifetime" + "github.com/milvus-io/milvus/pkg/util/syncutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// RecoverBalancer recover the balancer working. +func RecoverBalancer( + ctx context.Context, + policy string, + streamingNodeManager manager.ManagerClient, + incomingNewChannel ...string, // Concurrent incoming new channel directly from the configuration. + // we should add a rpc interface for creating new incoming new channel. +) (Balancer, error) { + // Recover the channel view from catalog. + manager, err := channel.RecoverChannelManager(ctx, incomingNewChannel...) + if err != nil { + return nil, errors.Wrap(err, "fail to recover channel manager") + } + b := &balancerImpl{ + lifetime: lifetime.NewLifetime(lifetime.Working), + logger: log.With(zap.String("policy", policy)), + streamingNodeManager: streamingNodeManager, // TODO: fill it up. + channelMetaManager: manager, + policy: mustGetPolicy(policy), + reqCh: make(chan *request, 5), + backgroundTaskNotifier: syncutil.NewAsyncTaskNotifier[struct{}](), + } + go b.execute() + return b, nil +} + +// balancerImpl is a implementation of Balancer. +type balancerImpl struct { + lifetime lifetime.Lifetime[lifetime.State] + logger *log.MLogger + streamingNodeManager manager.ManagerClient + channelMetaManager *channel.ChannelManager + policy Policy // policy is the balance policy, TODO: should be dynamic in future. + reqCh chan *request // reqCh is the request channel, send the operation to background task. + backgroundTaskNotifier *syncutil.AsyncTaskNotifier[struct{}] // backgroundTaskNotifier is used to conmunicate with the background task. +} + +// WatchBalanceResult watches the balance result. +func (b *balancerImpl) WatchBalanceResult(ctx context.Context, cb func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error) error { + if b.lifetime.Add(lifetime.IsWorking) != nil { + return status.NewOnShutdownError("balancer is closing") + } + defer b.lifetime.Done() + return b.channelMetaManager.WatchAssignmentResult(ctx, cb) +} + +func (b *balancerImpl) MarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) error { + if b.lifetime.Add(lifetime.IsWorking) != nil { + return status.NewOnShutdownError("balancer is closing") + } + defer b.lifetime.Done() + + return b.sendRequestAndWaitFinish(ctx, newOpMarkAsUnavailable(ctx, pChannels)) +} + +// Trigger trigger a re-balance. +func (b *balancerImpl) Trigger(ctx context.Context) error { + if b.lifetime.Add(lifetime.IsWorking) != nil { + return status.NewOnShutdownError("balancer is closing") + } + defer b.lifetime.Done() + + return b.sendRequestAndWaitFinish(ctx, newOpTrigger(ctx)) +} + +// sendRequestAndWaitFinish send a request to the background task and wait for it to finish. +func (b *balancerImpl) sendRequestAndWaitFinish(ctx context.Context, newReq *request) error { + select { + case <-ctx.Done(): + return ctx.Err() + case b.reqCh <- newReq: + } + return newReq.future.Get() +} + +// Close close the balancer. +func (b *balancerImpl) Close() { + b.lifetime.SetState(lifetime.Stopped) + b.lifetime.Wait() + + b.backgroundTaskNotifier.Cancel() + b.backgroundTaskNotifier.BlockUntilFinish() +} + +// execute the balancer. +func (b *balancerImpl) execute() { + b.logger.Info("balancer start to execute") + defer func() { + b.backgroundTaskNotifier.Finish(struct{}{}) + b.logger.Info("balancer execute finished") + }() + + balanceTimer := newBalanceTimer() + for { + // Wait for next balance trigger. + // Maybe trigger by timer or by request. + nextTimer, nextBalanceInterval := balanceTimer.NextTimer() + b.logger.Info("balance wait", zap.Duration("nextBalanceInterval", nextBalanceInterval)) + select { + case <-b.backgroundTaskNotifier.Context().Done(): + return + case newReq := <-b.reqCh: + newReq.apply(b) + b.applyAllRequest() + case <-nextTimer: + } + + if err := b.balance(b.backgroundTaskNotifier.Context()); err != nil { + if b.backgroundTaskNotifier.Context().Err() != nil { + // balancer is closed. + return + } + b.logger.Warn("fail to apply balance, start a backoff...") + balanceTimer.EnableBackoff() + continue + } + + b.logger.Info("apply balance success") + balanceTimer.DisableBackoff() + } +} + +// applyAllRequest apply all request in the request channel. +func (b *balancerImpl) applyAllRequest() { + for { + select { + case newReq := <-b.reqCh: + newReq.apply(b) + default: + return + } + } +} + +// Trigger a balance of layout. +// Return a nil chan to avoid +// Return a channel to notify the balance trigger again. +func (b *balancerImpl) balance(ctx context.Context) error { + b.logger.Info("start to balance") + pchannelView := b.channelMetaManager.CurrentPChannelsView() + + b.logger.Info("collect all status...") + nodeStatus, err := b.streamingNodeManager.CollectAllStatus(ctx) + if err != nil { + return errors.Wrap(err, "fail to collect all status") + } + + // call the balance strategy to generate the expected layout. + currentLayout := generateCurrentLayout(pchannelView, nodeStatus) + expectedLayout, err := b.policy.Balance(currentLayout) + if err != nil { + return errors.Wrap(err, "fail to balance") + } + + b.logger.Info("balance policy generate result success, try to assign...", zap.Any("expectedLayout", expectedLayout)) + // bookkeeping the meta assignment started. + modifiedChannels, err := b.channelMetaManager.AssignPChannels(ctx, expectedLayout.ChannelAssignment) + if err != nil { + return errors.Wrap(err, "fail to assign pchannels") + } + + if len(modifiedChannels) == 0 { + b.logger.Info("no change of balance result need to be applied") + return nil + } + return b.applyBalanceResultToStreamingNode(ctx, modifiedChannels) +} + +// applyBalanceResultToStreamingNode apply the balance result to streaming node. +func (b *balancerImpl) applyBalanceResultToStreamingNode(ctx context.Context, modifiedChannels map[string]*channel.PChannelMeta) error { + b.logger.Info("balance result need to be applied...", zap.Int("modifiedChannelCount", len(modifiedChannels))) + + // different channel can be execute concurrently. + g, _ := errgroup.WithContext(ctx) + // generate balance operations and applied them. + for _, channel := range modifiedChannels { + channel := channel + g.Go(func() error { + // all history channels should be remove from related nodes. + for _, assignment := range channel.AssignHistories() { + if err := b.streamingNodeManager.Remove(ctx, assignment); err != nil { + b.logger.Warn("fail to remove channel", zap.Any("assignment", assignment)) + return err + } + b.logger.Info("remove channel success", zap.Any("assignment", assignment)) + } + + // assign the channel to the target node. + if err := b.streamingNodeManager.Assign(ctx, channel.CurrentAssignment()); err != nil { + b.logger.Warn("fail to assign channel", zap.Any("assignment", channel.CurrentAssignment())) + return err + } + b.logger.Info("assign channel success", zap.Any("assignment", channel.CurrentAssignment())) + + // bookkeeping the meta assignment done. + if err := b.channelMetaManager.AssignPChannelsDone(ctx, []string{channel.Name()}); err != nil { + b.logger.Warn("fail to bookkeep pchannel assignment done", zap.Any("assignment", channel.CurrentAssignment())) + return err + } + return nil + }) + } + return g.Wait() +} + +// generateCurrentLayout generate layout from all nodes info and meta. +func generateCurrentLayout(channelsInMeta map[string]*channel.PChannelMeta, allNodesStatus map[int64]types.StreamingNodeStatus) (layout CurrentLayout) { + activeRelations := make(map[int64][]types.PChannelInfo, len(allNodesStatus)) + incomingChannels := make([]string, 0) + channelsToNodes := make(map[string]int64, len(channelsInMeta)) + assigned := make(map[int64][]types.PChannelInfo, len(allNodesStatus)) + for _, meta := range channelsInMeta { + if !meta.IsAssigned() { + incomingChannels = append(incomingChannels, meta.Name()) + // dead or expired relationship. + log.Warn("channel is not assigned to any server", + zap.String("channel", meta.Name()), + zap.Int64("term", meta.CurrentTerm()), + zap.Int64("serverID", meta.CurrentServerID()), + zap.String("state", meta.State().String()), + ) + continue + } + if nodeStatus, ok := allNodesStatus[meta.CurrentServerID()]; ok && nodeStatus.IsHealthy() { + // active relationship. + activeRelations[meta.CurrentServerID()] = append(activeRelations[meta.CurrentServerID()], types.PChannelInfo{ + Name: meta.Name(), + Term: meta.CurrentTerm(), + }) + channelsToNodes[meta.Name()] = meta.CurrentServerID() + assigned[meta.CurrentServerID()] = append(assigned[meta.CurrentServerID()], meta.ChannelInfo()) + } else { + incomingChannels = append(incomingChannels, meta.Name()) + // dead or expired relationship. + log.Warn("channel of current server id is not healthy or not alive", + zap.String("channel", meta.Name()), + zap.Int64("term", meta.CurrentTerm()), + zap.Int64("serverID", meta.CurrentServerID()), + zap.Error(nodeStatus.Err), + ) + } + } + + allNodesInfo := make(map[int64]types.StreamingNodeInfo, len(allNodesStatus)) + for serverID, nodeStatus := range allNodesStatus { + // filter out the unhealthy nodes. + if nodeStatus.IsHealthy() { + allNodesInfo[serverID] = nodeStatus.StreamingNodeInfo + } + } + + return CurrentLayout{ + IncomingChannels: incomingChannels, + ChannelsToNodes: channelsToNodes, + AssignedChannels: assigned, + AllNodesInfo: allNodesInfo, + } +} diff --git a/internal/streamingcoord/server/balancer/balancer_test.go b/internal/streamingcoord/server/balancer/balancer_test.go new file mode 100644 index 0000000000000..f495bc9385ee0 --- /dev/null +++ b/internal/streamingcoord/server/balancer/balancer_test.go @@ -0,0 +1,115 @@ +package balancer_test + +import ( + "context" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus/internal/mocks/mock_metastore" + "github.com/milvus-io/milvus/internal/mocks/streamingnode/client/mock_manager" + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + _ "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/policy" + "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +func TestBalancer(t *testing.T) { + paramtable.Init() + + streamingNodeManager := mock_manager.NewMockManagerClient(t) + streamingNodeManager.EXPECT().Assign(mock.Anything, mock.Anything).Return(nil) + streamingNodeManager.EXPECT().Remove(mock.Anything, mock.Anything).Return(nil) + streamingNodeManager.EXPECT().CollectAllStatus(mock.Anything).Return(map[int64]types.StreamingNodeStatus{ + 1: { + StreamingNodeInfo: types.StreamingNodeInfo{ + ServerID: 1, + Address: "localhost:1", + }, + }, + 2: { + StreamingNodeInfo: types.StreamingNodeInfo{ + ServerID: 2, + Address: "localhost:2", + }, + }, + 3: { + StreamingNodeInfo: types.StreamingNodeInfo{ + ServerID: 3, + Address: "localhost:3", + }, + }, + 4: { + StreamingNodeInfo: types.StreamingNodeInfo{ + ServerID: 3, + Address: "localhost:3", + }, + Err: types.ErrStopping, + }, + }, nil) + + catalog := mock_metastore.NewMockStreamingCoordCataLog(t) + resource.InitForTest(resource.OptStreamingCatalog(catalog)) + catalog.EXPECT().ListPChannel(mock.Anything).Unset() + catalog.EXPECT().ListPChannel(mock.Anything).RunAndReturn(func(ctx context.Context) ([]*streamingpb.PChannelMeta, error) { + return []*streamingpb.PChannelMeta{ + { + Channel: &streamingpb.PChannelInfo{ + Name: "test-channel-1", + Term: 1, + }, + State: streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNED, + Node: &streamingpb.StreamingNodeInfo{ServerId: 1}, + }, + { + Channel: &streamingpb.PChannelInfo{ + Name: "test-channel-2", + Term: 1, + }, + State: streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNED, + Node: &streamingpb.StreamingNodeInfo{ServerId: 4}, + }, + { + Channel: &streamingpb.PChannelInfo{ + Name: "test-channel-3", + Term: 2, + }, + State: streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNING, + Node: &streamingpb.StreamingNodeInfo{ServerId: 2}, + }, + }, nil + }) + catalog.EXPECT().SavePChannels(mock.Anything, mock.Anything).Return(nil).Maybe() + + ctx := context.Background() + b, err := balancer.RecoverBalancer(ctx, "pchannel_count_fair", streamingNodeManager) + assert.NoError(t, err) + assert.NotNil(t, b) + defer b.Close() + + b.MarkAsUnavailable(ctx, []types.PChannelInfo{{ + Name: "test-channel-1", + Term: 1, + }}) + b.Trigger(ctx) + + doneErr := errors.New("done") + err = b.WatchBalanceResult(ctx, func(version typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error { + // should one pchannel be assigned to per nodes + nodeIDs := typeutil.NewSet[int64]() + if len(relations) == 3 { + for _, status := range relations { + nodeIDs.Insert(status.Node.ServerID) + } + assert.Equal(t, 3, nodeIDs.Len()) + return doneErr + } + return nil + }) + assert.ErrorIs(t, err, doneErr) +} diff --git a/internal/streamingcoord/server/balancer/channel/manager.go b/internal/streamingcoord/server/balancer/channel/manager.go new file mode 100644 index 0000000000000..4197bff0ab67f --- /dev/null +++ b/internal/streamingcoord/server/balancer/channel/manager.go @@ -0,0 +1,223 @@ +package channel + +import ( + "context" + "sync" + + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" + "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +var ErrChannelNotExist = errors.New("channel not exist") + +// RecoverChannelManager creates a new channel manager. +func RecoverChannelManager(ctx context.Context, incomingChannel ...string) (*ChannelManager, error) { + channels, err := recoverFromConfigurationAndMeta(ctx, incomingChannel...) + if err != nil { + return nil, err + } + globalVersion := paramtable.GetNodeID() + return &ChannelManager{ + cond: syncutil.NewContextCond(&sync.Mutex{}), + channels: channels, + version: typeutil.VersionInt64Pair{ + Global: globalVersion, // global version should be keep increasing globally, it's ok to use node id. + Local: 0, + }, + }, nil +} + +// recoverFromConfigurationAndMeta recovers the channel manager from configuration and meta. +func recoverFromConfigurationAndMeta(ctx context.Context, incomingChannel ...string) (map[string]*PChannelMeta, error) { + // Get all channels from meta. + channelMetas, err := resource.Resource().StreamingCatalog().ListPChannel(ctx) + if err != nil { + return nil, err + } + + channels := make(map[string]*PChannelMeta, len(channelMetas)) + for _, channel := range channelMetas { + channels[channel.GetChannel().GetName()] = newPChannelMetaFromProto(channel) + } + + // Get new incoming meta from configuration. + for _, newChannel := range incomingChannel { + if _, ok := channels[newChannel]; !ok { + channels[newChannel] = newPChannelMeta(newChannel) + } + } + return channels, nil +} + +// ChannelManager manages the channels. +// ChannelManager is the `wal` of channel assignment and unassignment. +// Every operation applied to the streaming node should be recorded in ChannelManager first. +type ChannelManager struct { + cond *syncutil.ContextCond + channels map[string]*PChannelMeta + version typeutil.VersionInt64Pair +} + +// CurrentPChannelsView returns the current view of pchannels. +func (cm *ChannelManager) CurrentPChannelsView() map[string]*PChannelMeta { + cm.cond.L.Lock() + defer cm.cond.L.Unlock() + + channels := make(map[string]*PChannelMeta, len(cm.channels)) + for k, v := range cm.channels { + channels[k] = v + } + return channels +} + +// AssignPChannels update the pchannels to servers and return the modified pchannels. +// When the balancer want to assign a pchannel into a new server. +// It should always call this function to update the pchannel assignment first. +// Otherwise, the pchannel assignment tracing is lost at meta. +func (cm *ChannelManager) AssignPChannels(ctx context.Context, pChannelToStreamingNode map[string]types.StreamingNodeInfo) (map[string]*PChannelMeta, error) { + cm.cond.LockAndBroadcast() + defer cm.cond.L.Unlock() + + // modified channels. + pChannelMetas := make([]*streamingpb.PChannelMeta, 0, len(pChannelToStreamingNode)) + for channelName, streamingNode := range pChannelToStreamingNode { + pchannel, ok := cm.channels[channelName] + if !ok { + return nil, ErrChannelNotExist + } + mutablePchannel := pchannel.CopyForWrite() + if mutablePchannel.TryAssignToServerID(streamingNode) { + pChannelMetas = append(pChannelMetas, mutablePchannel.IntoRawMeta()) + } + } + + err := cm.updatePChannelMeta(ctx, pChannelMetas) + if err != nil { + return nil, err + } + + updates := make(map[string]*PChannelMeta, len(pChannelMetas)) + for _, pchannel := range pChannelMetas { + updates[pchannel.GetChannel().GetName()] = newPChannelMetaFromProto(pchannel) + } + return updates, nil +} + +// AssignPChannelsDone clear up the history data of the pchannels and transfer the state into assigned. +// When the balancer want to cleanup the history data of a pchannel. +// It should always remove the pchannel on the server first. +// Otherwise, the pchannel assignment tracing is lost at meta. +func (cm *ChannelManager) AssignPChannelsDone(ctx context.Context, pChannels []string) error { + cm.cond.LockAndBroadcast() + defer cm.cond.L.Unlock() + + // modified channels. + pChannelMetas := make([]*streamingpb.PChannelMeta, 0, len(pChannels)) + for _, channelName := range pChannels { + pchannel, ok := cm.channels[channelName] + if !ok { + return ErrChannelNotExist + } + mutablePChannel := pchannel.CopyForWrite() + mutablePChannel.AssignToServerDone() + pChannelMetas = append(pChannelMetas, mutablePChannel.IntoRawMeta()) + } + + return cm.updatePChannelMeta(ctx, pChannelMetas) +} + +// MarkAsUnavailable mark the pchannels as unavailable. +func (cm *ChannelManager) MarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) error { + cm.cond.LockAndBroadcast() + defer cm.cond.L.Unlock() + + // modified channels. + pChannelMetas := make([]*streamingpb.PChannelMeta, 0, len(pChannels)) + for _, channel := range pChannels { + pchannel, ok := cm.channels[channel.Name] + if !ok { + return ErrChannelNotExist + } + mutablePChannel := pchannel.CopyForWrite() + mutablePChannel.MarkAsUnavailable(channel.Term) + pChannelMetas = append(pChannelMetas, mutablePChannel.IntoRawMeta()) + } + + return cm.updatePChannelMeta(ctx, pChannelMetas) +} + +// updatePChannelMeta updates the pchannel metas. +func (cm *ChannelManager) updatePChannelMeta(ctx context.Context, pChannelMetas []*streamingpb.PChannelMeta) error { + if len(pChannelMetas) == 0 { + return nil + } + if err := resource.Resource().StreamingCatalog().SavePChannels(ctx, pChannelMetas); err != nil { + return errors.Wrap(err, "update meta at catalog") + } + + // update in-memory copy and increase the version. + for _, pchannel := range pChannelMetas { + cm.channels[pchannel.GetChannel().GetName()] = newPChannelMetaFromProto(pchannel) + } + cm.version.Local++ + // update metrics. + metrics.StreamingCoordAssignmentVersion.WithLabelValues( + paramtable.GetStringNodeID(), + ).Set(float64(cm.version.Local)) + return nil +} + +func (cm *ChannelManager) WatchAssignmentResult(ctx context.Context, cb func(version typeutil.VersionInt64Pair, assignments []types.PChannelInfoAssigned) error) error { + // push the first balance result to watcher callback function if balance result is ready. + version, err := cm.applyAssignments(cb) + if err != nil { + return err + } + for { + // wait for version change, and apply the latest assignment to callback. + if err := cm.waitChanges(ctx, version); err != nil { + return err + } + if version, err = cm.applyAssignments(cb); err != nil { + return err + } + } +} + +// applyAssignments applies the assignments. +func (cm *ChannelManager) applyAssignments(cb func(version typeutil.VersionInt64Pair, assignments []types.PChannelInfoAssigned) error) (typeutil.VersionInt64Pair, error) { + cm.cond.L.Lock() + assignments, version := cm.getAssignments() + cm.cond.L.Unlock() + return version, cb(version, assignments) +} + +// getAssignments returns the current assignments. +func (cm *ChannelManager) getAssignments() ([]types.PChannelInfoAssigned, typeutil.VersionInt64Pair) { + assignments := make([]types.PChannelInfoAssigned, 0, len(cm.channels)) + for _, c := range cm.channels { + if c.IsAssigned() { + assignments = append(assignments, c.CurrentAssignment()) + } + } + return assignments, cm.version +} + +// waitChanges waits for the layout to be updated. +func (cm *ChannelManager) waitChanges(ctx context.Context, version typeutil.Version) error { + cm.cond.L.Lock() + for version.EQ(cm.version) { + if err := cm.cond.Wait(ctx); err != nil { + return err + } + } + cm.cond.L.Unlock() + return nil +} diff --git a/internal/streamingcoord/server/balancer/channel/manager_test.go b/internal/streamingcoord/server/balancer/channel/manager_test.go new file mode 100644 index 0000000000000..1e4242cb4f2d5 --- /dev/null +++ b/internal/streamingcoord/server/balancer/channel/manager_test.go @@ -0,0 +1,143 @@ +package channel + +import ( + "context" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus/internal/mocks/mock_metastore" + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +func TestChannelManager(t *testing.T) { + catalog := mock_metastore.NewMockStreamingCoordCataLog(t) + resource.InitForTest(resource.OptStreamingCatalog(catalog)) + + ctx := context.Background() + // Test recover failure. + catalog.EXPECT().ListPChannel(mock.Anything).Return(nil, errors.New("recover failure")) + m, err := RecoverChannelManager(ctx) + assert.Nil(t, m) + assert.Error(t, err) + + catalog.EXPECT().ListPChannel(mock.Anything).Unset() + catalog.EXPECT().ListPChannel(mock.Anything).RunAndReturn(func(ctx context.Context) ([]*streamingpb.PChannelMeta, error) { + return []*streamingpb.PChannelMeta{ + { + Channel: &streamingpb.PChannelInfo{ + Name: "test-channel", + Term: 1, + }, + Node: &streamingpb.StreamingNodeInfo{ + ServerId: 1, + }, + }, + }, nil + }) + m, err = RecoverChannelManager(ctx) + assert.NotNil(t, m) + assert.NoError(t, err) + + // Test save meta failure + catalog.EXPECT().SavePChannels(mock.Anything, mock.Anything).Return(errors.New("save meta failure")) + modified, err := m.AssignPChannels(ctx, map[string]types.StreamingNodeInfo{"test-channel": {ServerID: 2}}) + assert.Nil(t, modified) + assert.Error(t, err) + err = m.AssignPChannelsDone(ctx, []string{"test-channel"}) + assert.Error(t, err) + err = m.MarkAsUnavailable(ctx, []types.PChannelInfo{{ + Name: "test-channel", + Term: 2, + }}) + assert.Error(t, err) + + // Test update non exist pchannel + modified, err = m.AssignPChannels(ctx, map[string]types.StreamingNodeInfo{"non-exist-channel": {ServerID: 2}}) + assert.Nil(t, modified) + assert.ErrorIs(t, err, ErrChannelNotExist) + err = m.AssignPChannelsDone(ctx, []string{"non-exist-channel"}) + assert.ErrorIs(t, err, ErrChannelNotExist) + err = m.MarkAsUnavailable(ctx, []types.PChannelInfo{{ + Name: "non-exist-channel", + Term: 2, + }}) + assert.ErrorIs(t, err, ErrChannelNotExist) + + // Test success. + catalog.EXPECT().SavePChannels(mock.Anything, mock.Anything).Unset() + catalog.EXPECT().SavePChannels(mock.Anything, mock.Anything).Return(nil) + modified, err = m.AssignPChannels(ctx, map[string]types.StreamingNodeInfo{"test-channel": {ServerID: 2}}) + assert.NotNil(t, modified) + assert.NoError(t, err) + assert.Len(t, modified, 1) + err = m.AssignPChannelsDone(ctx, []string{"test-channel"}) + assert.NoError(t, err) + err = m.MarkAsUnavailable(ctx, []types.PChannelInfo{{ + Name: "test-channel", + Term: 2, + }}) + assert.NoError(t, err) + + view := m.CurrentPChannelsView() + assert.NotNil(t, view) + assert.Len(t, view, 1) + assert.NotNil(t, view["test-channel"]) +} + +func TestChannelManagerWatch(t *testing.T) { + catalog := mock_metastore.NewMockStreamingCoordCataLog(t) + resource.InitForTest(resource.OptStreamingCatalog(catalog)) + + catalog.EXPECT().ListPChannel(mock.Anything).Unset() + catalog.EXPECT().ListPChannel(mock.Anything).RunAndReturn(func(ctx context.Context) ([]*streamingpb.PChannelMeta, error) { + return []*streamingpb.PChannelMeta{ + { + Channel: &streamingpb.PChannelInfo{ + Name: "test-channel", + Term: 1, + }, + Node: &streamingpb.StreamingNodeInfo{ + ServerId: 1, + }, + State: streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNED, + }, + }, nil + }) + catalog.EXPECT().SavePChannels(mock.Anything, mock.Anything).Return(nil) + + manager, err := RecoverChannelManager(context.Background()) + assert.NoError(t, err) + done := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + + called := make(chan struct{}, 1) + go func() { + defer close(done) + err := manager.WatchAssignmentResult(ctx, func(version typeutil.VersionInt64Pair, assignments []types.PChannelInfoAssigned) error { + select { + case called <- struct{}{}: + default: + } + return nil + }) + assert.ErrorIs(t, err, context.Canceled) + }() + + manager.AssignPChannels(ctx, map[string]types.StreamingNodeInfo{"test-channel": {ServerID: 2}}) + manager.AssignPChannelsDone(ctx, []string{"test-channel"}) + + <-called + manager.MarkAsUnavailable(ctx, []types.PChannelInfo{{ + Name: "test-channel", + Term: 2, + }}) + <-called + cancel() + <-done +} diff --git a/internal/streamingcoord/server/balancer/channel/pchannel.go b/internal/streamingcoord/server/balancer/channel/pchannel.go new file mode 100644 index 0000000000000..e4b79d1fafd4f --- /dev/null +++ b/internal/streamingcoord/server/balancer/channel/pchannel.go @@ -0,0 +1,150 @@ +package channel + +import ( + "github.com/golang/protobuf/proto" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/util/streamingutil/typeconverter" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +// newPChannelMeta creates a new PChannelMeta. +func newPChannelMeta(name string) *PChannelMeta { + return &PChannelMeta{ + inner: &streamingpb.PChannelMeta{ + Channel: &streamingpb.PChannelInfo{ + Name: name, + Term: 1, + }, + Node: nil, + State: streamingpb.PChannelMetaState_PCHANNEL_META_STATE_UNINITIALIZED, + Histories: make([]*streamingpb.PChannelMetaHistory, 0), + }, + } +} + +// newPChannelMetaFromProto creates a new PChannelMeta from proto. +func newPChannelMetaFromProto(channel *streamingpb.PChannelMeta) *PChannelMeta { + return &PChannelMeta{ + inner: channel, + } +} + +// PChannelMeta is the read only version of PChannelInfo, to be used in balancer, +// If you need to update PChannelMeta, please use CopyForWrite to get mutablePChannel. +type PChannelMeta struct { + inner *streamingpb.PChannelMeta +} + +// Name returns the name of the channel. +func (c *PChannelMeta) Name() string { + return c.inner.GetChannel().GetName() +} + +// ChannelInfo returns the channel info. +func (c *PChannelMeta) ChannelInfo() types.PChannelInfo { + return typeconverter.NewPChannelInfoFromProto(c.inner.Channel) +} + +// Term returns the current term of the channel. +func (c *PChannelMeta) CurrentTerm() int64 { + return c.inner.GetChannel().GetTerm() +} + +// CurrentServerID returns the server id of the channel. +// If the channel is not assigned to any server, return -1. +func (c *PChannelMeta) CurrentServerID() int64 { + return c.inner.GetNode().GetServerId() +} + +// CurrentAssignment returns the current assignment of the channel. +func (c *PChannelMeta) CurrentAssignment() types.PChannelInfoAssigned { + return types.PChannelInfoAssigned{ + Channel: typeconverter.NewPChannelInfoFromProto(c.inner.Channel), + Node: typeconverter.NewStreamingNodeInfoFromProto(c.inner.Node), + } +} + +// AssignHistories returns the history of the channel assignment. +func (c *PChannelMeta) AssignHistories() []types.PChannelInfoAssigned { + history := make([]types.PChannelInfoAssigned, 0, len(c.inner.Histories)) + for _, h := range c.inner.Histories { + history = append(history, types.PChannelInfoAssigned{ + Channel: types.PChannelInfo{ + Name: c.inner.GetChannel().GetName(), + Term: h.Term, + }, + Node: typeconverter.NewStreamingNodeInfoFromProto(h.Node), + }) + } + return history +} + +// IsAssigned returns if the channel is assigned to a server. +func (c *PChannelMeta) IsAssigned() bool { + return c.inner.State == streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNED +} + +// State returns the state of the channel. +func (c *PChannelMeta) State() streamingpb.PChannelMetaState { + return c.inner.State +} + +// CopyForWrite returns mutablePChannel to modify pchannel +// but didn't affect other replicas. +func (c *PChannelMeta) CopyForWrite() *mutablePChannel { + return &mutablePChannel{ + PChannelMeta: &PChannelMeta{ + inner: proto.Clone(c.inner).(*streamingpb.PChannelMeta), + }, + } +} + +// mutablePChannel is a mutable version of PChannel. +// use to update the channel info. +type mutablePChannel struct { + *PChannelMeta +} + +// TryAssignToServerID assigns the channel to a server. +func (m *mutablePChannel) TryAssignToServerID(streamingNode types.StreamingNodeInfo) bool { + if m.CurrentServerID() == streamingNode.ServerID && m.inner.State == streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNED { + // if the channel is already assigned to the server, return false. + return false + } + if m.inner.State != streamingpb.PChannelMetaState_PCHANNEL_META_STATE_UNINITIALIZED { + // if the channel is already initialized, add the history. + m.inner.Histories = append(m.inner.Histories, &streamingpb.PChannelMetaHistory{ + Term: m.inner.Channel.Term, + Node: m.inner.Node, + }) + } + + // otherwise update the channel into assgining state. + m.inner.Channel.Term++ + m.inner.Node = typeconverter.NewProtoFromStreamingNodeInfo(streamingNode) + m.inner.State = streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNING + return true +} + +// AssignToServerDone assigns the channel to the server done. +func (m *mutablePChannel) AssignToServerDone() { + if m.inner.State == streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNING { + m.inner.Histories = make([]*streamingpb.PChannelMetaHistory, 0) + m.inner.State = streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNED + } +} + +// MarkAsUnavailable marks the channel as unavailable. +func (m *mutablePChannel) MarkAsUnavailable(term int64) { + if m.inner.State == streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNED && m.CurrentTerm() == term { + m.inner.State = streamingpb.PChannelMetaState_PCHANNEL_META_STATE_UNAVAILABLE + } +} + +// IntoRawMeta returns the raw meta, no longger available after call. +func (m *mutablePChannel) IntoRawMeta() *streamingpb.PChannelMeta { + c := m.PChannelMeta + m.PChannelMeta = nil + return c.inner +} diff --git a/internal/streamingcoord/server/balancer/channel/pchannel_test.go b/internal/streamingcoord/server/balancer/channel/pchannel_test.go new file mode 100644 index 0000000000000..a5a0b85a4d1a9 --- /dev/null +++ b/internal/streamingcoord/server/balancer/channel/pchannel_test.go @@ -0,0 +1,107 @@ +package channel + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +func TestPChannel(t *testing.T) { + pchannel := newPChannelMetaFromProto(&streamingpb.PChannelMeta{ + Channel: &streamingpb.PChannelInfo{ + Name: "test-channel", + Term: 1, + }, + Node: &streamingpb.StreamingNodeInfo{ + ServerId: 123, + }, + State: streamingpb.PChannelMetaState_PCHANNEL_META_STATE_UNINITIALIZED, + }) + assert.Equal(t, "test-channel", pchannel.Name()) + assert.Equal(t, int64(1), pchannel.CurrentTerm()) + assert.Equal(t, int64(123), pchannel.CurrentServerID()) + assert.Equal(t, streamingpb.PChannelMetaState_PCHANNEL_META_STATE_UNINITIALIZED, pchannel.State()) + assert.False(t, pchannel.IsAssigned()) + assert.Empty(t, pchannel.AssignHistories()) + assert.Equal(t, types.PChannelInfoAssigned{ + Channel: types.PChannelInfo{ + Name: "test-channel", + Term: 1, + }, + Node: types.StreamingNodeInfo{ + ServerID: 123, + }, + }, pchannel.CurrentAssignment()) + + pchannel = newPChannelMeta("test-channel") + assert.Equal(t, "test-channel", pchannel.Name()) + assert.Equal(t, int64(1), pchannel.CurrentTerm()) + assert.Empty(t, pchannel.AssignHistories()) + assert.False(t, pchannel.IsAssigned()) + + // Test CopyForWrite() + mutablePChannel := pchannel.CopyForWrite() + assert.NotNil(t, mutablePChannel) + + // Test AssignToServerID() + newServerID := types.StreamingNodeInfo{ + ServerID: 456, + } + assert.True(t, mutablePChannel.TryAssignToServerID(newServerID)) + updatedChannelInfo := newPChannelMetaFromProto(mutablePChannel.IntoRawMeta()) + + assert.Equal(t, "test-channel", pchannel.Name()) + assert.Equal(t, int64(1), pchannel.CurrentTerm()) + assert.Empty(t, pchannel.AssignHistories()) + + assert.Equal(t, "test-channel", updatedChannelInfo.Name()) + assert.Equal(t, int64(2), updatedChannelInfo.CurrentTerm()) + assert.Equal(t, int64(456), updatedChannelInfo.CurrentServerID()) + assert.Empty(t, pchannel.AssignHistories()) + assert.False(t, updatedChannelInfo.IsAssigned()) + assert.Equal(t, streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNING, updatedChannelInfo.State()) + + mutablePChannel = updatedChannelInfo.CopyForWrite() + + mutablePChannel.TryAssignToServerID(types.StreamingNodeInfo{ServerID: 789}) + updatedChannelInfo = newPChannelMetaFromProto(mutablePChannel.IntoRawMeta()) + assert.Equal(t, "test-channel", updatedChannelInfo.Name()) + assert.Equal(t, int64(3), updatedChannelInfo.CurrentTerm()) + assert.Equal(t, int64(789), updatedChannelInfo.CurrentServerID()) + assert.Len(t, updatedChannelInfo.AssignHistories(), 1) + assert.Equal(t, "test-channel", updatedChannelInfo.AssignHistories()[0].Channel.Name) + assert.Equal(t, int64(2), updatedChannelInfo.AssignHistories()[0].Channel.Term) + assert.Equal(t, int64(456), updatedChannelInfo.AssignHistories()[0].Node.ServerID) + assert.False(t, updatedChannelInfo.IsAssigned()) + assert.Equal(t, streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNING, updatedChannelInfo.State()) + + // Test AssignToServerDone + mutablePChannel = updatedChannelInfo.CopyForWrite() + mutablePChannel.AssignToServerDone() + updatedChannelInfo = newPChannelMetaFromProto(mutablePChannel.IntoRawMeta()) + assert.Equal(t, "test-channel", updatedChannelInfo.Name()) + assert.Equal(t, int64(3), updatedChannelInfo.CurrentTerm()) + assert.Equal(t, int64(789), updatedChannelInfo.CurrentServerID()) + assert.Len(t, updatedChannelInfo.AssignHistories(), 0) + assert.True(t, updatedChannelInfo.IsAssigned()) + assert.Equal(t, streamingpb.PChannelMetaState_PCHANNEL_META_STATE_ASSIGNED, updatedChannelInfo.State()) + + // Test reassigned + mutablePChannel = updatedChannelInfo.CopyForWrite() + assert.False(t, mutablePChannel.TryAssignToServerID(types.StreamingNodeInfo{ServerID: 789})) + + // Test MarkAsUnavailable + mutablePChannel = updatedChannelInfo.CopyForWrite() + mutablePChannel.MarkAsUnavailable(2) + updatedChannelInfo = newPChannelMetaFromProto(mutablePChannel.IntoRawMeta()) + assert.True(t, updatedChannelInfo.IsAssigned()) + + mutablePChannel = updatedChannelInfo.CopyForWrite() + mutablePChannel.MarkAsUnavailable(3) + updatedChannelInfo = newPChannelMetaFromProto(mutablePChannel.IntoRawMeta()) + assert.False(t, updatedChannelInfo.IsAssigned()) + assert.Equal(t, streamingpb.PChannelMetaState_PCHANNEL_META_STATE_UNAVAILABLE, updatedChannelInfo.State()) +} diff --git a/internal/streamingcoord/server/balancer/policy/init.go b/internal/streamingcoord/server/balancer/policy/init.go new file mode 100644 index 0000000000000..a1ffb14fe89ae --- /dev/null +++ b/internal/streamingcoord/server/balancer/policy/init.go @@ -0,0 +1,7 @@ +package policy + +import "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + +func init() { + balancer.RegisterPolicy(&pchannelCountFairPolicy{}) +} diff --git a/internal/streamingcoord/server/balancer/policy/pchannel_count_fair.go b/internal/streamingcoord/server/balancer/policy/pchannel_count_fair.go new file mode 100644 index 0000000000000..aa7e6daa6b826 --- /dev/null +++ b/internal/streamingcoord/server/balancer/policy/pchannel_count_fair.go @@ -0,0 +1,69 @@ +package policy + +import ( + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +var _ balancer.Policy = &pchannelCountFairPolicy{} + +// pchannelCountFairPolicy is a policy to balance the load of log node by channel count. +// Make sure the channel count of each streaming node is equal or differ by 1. +type pchannelCountFairPolicy struct{} + +func (p *pchannelCountFairPolicy) Name() string { + return "pchannel_count_fair" +} + +func (p *pchannelCountFairPolicy) Balance(currentLayout balancer.CurrentLayout) (expectedLayout balancer.ExpectedLayout, err error) { + if currentLayout.TotalNodes() == 0 { + return balancer.ExpectedLayout{}, errors.New("no available streaming node") + } + + // Get the average and remaining channel count of all streaming node. + avgChannelCount := currentLayout.TotalChannels() / currentLayout.TotalNodes() + remainingChannelCount := currentLayout.TotalChannels() % currentLayout.TotalNodes() + + assignments := make(map[string]types.StreamingNodeInfo, currentLayout.TotalChannels()) + nodesChannelCount := make(map[int64]int, currentLayout.TotalNodes()) + needAssignChannel := currentLayout.IncomingChannels + + // keep the channel already on the node. + for serverID, nodeInfo := range currentLayout.AllNodesInfo { + nodesChannelCount[serverID] = 0 + for i, channelInfo := range currentLayout.AssignedChannels[serverID] { + if i < avgChannelCount { + assignments[channelInfo.Name] = nodeInfo + nodesChannelCount[serverID]++ + } else if i == avgChannelCount && remainingChannelCount > 0 { + assignments[channelInfo.Name] = nodeInfo + nodesChannelCount[serverID]++ + remainingChannelCount-- + } else { + needAssignChannel = append(needAssignChannel, channelInfo.Name) + } + } + } + + // assign the incoming node to the node with least channel count. + for serverID, assignedChannelCount := range nodesChannelCount { + assignCount := 0 + if assignedChannelCount < avgChannelCount { + assignCount = avgChannelCount - assignedChannelCount + } else if assignedChannelCount == avgChannelCount && remainingChannelCount > 0 { + assignCount = 1 + remainingChannelCount-- + } + for i := 0; i < assignCount; i++ { + assignments[needAssignChannel[i]] = currentLayout.AllNodesInfo[serverID] + nodesChannelCount[serverID]++ + } + needAssignChannel = needAssignChannel[assignCount:] + } + + return balancer.ExpectedLayout{ + ChannelAssignment: assignments, + }, nil +} diff --git a/internal/streamingcoord/server/balancer/policy/pchannel_count_fair_test.go b/internal/streamingcoord/server/balancer/policy/pchannel_count_fair_test.go new file mode 100644 index 0000000000000..48c1c2881faa4 --- /dev/null +++ b/internal/streamingcoord/server/balancer/policy/pchannel_count_fair_test.go @@ -0,0 +1,183 @@ +package policy + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +func TestPChannelCountFair(t *testing.T) { + policy := &pchannelCountFairPolicy{} + assert.Equal(t, "pchannel_count_fair", policy.Name()) + expected, err := policy.Balance(balancer.CurrentLayout{ + IncomingChannels: []string{ + "c8", + "c9", + "c10", + }, + AllNodesInfo: map[int64]types.StreamingNodeInfo{ + 1: {ServerID: 1}, + 2: {ServerID: 2}, + 3: {ServerID: 3}, + }, + AssignedChannels: map[int64][]types.PChannelInfo{ + 1: {}, + 2: { + {Name: "c1"}, + {Name: "c3"}, + {Name: "c4"}, + }, + 3: { + {Name: "c2"}, + {Name: "c5"}, + {Name: "c6"}, + {Name: "c7"}, + }, + }, + ChannelsToNodes: map[string]int64{ + "c1": 2, + "c3": 2, + "c4": 2, + "c2": 3, + "c5": 3, + "c6": 3, + "c7": 3, + }, + }) + + assert.Equal(t, 10, len(expected.ChannelAssignment)) + assert.Equal(t, int64(2), expected.ChannelAssignment["c1"].ServerID) + assert.Equal(t, int64(2), expected.ChannelAssignment["c3"].ServerID) + assert.Equal(t, int64(2), expected.ChannelAssignment["c4"].ServerID) + assert.Equal(t, int64(3), expected.ChannelAssignment["c2"].ServerID) + assert.Equal(t, int64(3), expected.ChannelAssignment["c5"].ServerID) + assert.Equal(t, int64(3), expected.ChannelAssignment["c6"].ServerID) + assert.Equal(t, int64(3), expected.ChannelAssignment["c7"].ServerID) + counts := countByServerID(expected) + assert.Equal(t, 3, len(counts)) + for _, count := range counts { + assert.GreaterOrEqual(t, count, 3) + assert.LessOrEqual(t, count, 4) + } + assert.NoError(t, err) + + assert.Equal(t, "pchannel_count_fair", policy.Name()) + expected, err = policy.Balance(balancer.CurrentLayout{ + IncomingChannels: []string{ + "c8", + "c9", + "c10", + }, + AllNodesInfo: map[int64]types.StreamingNodeInfo{ + 1: {ServerID: 1}, + 2: {ServerID: 2}, + 3: {ServerID: 3}, + }, + AssignedChannels: map[int64][]types.PChannelInfo{ + 1: {}, + 2: { + {Name: "c1"}, + {Name: "c4"}, + }, + 3: { + {Name: "c2"}, + {Name: "c3"}, + {Name: "c5"}, + {Name: "c6"}, + {Name: "c7"}, + }, + }, + ChannelsToNodes: map[string]int64{ + "c1": 2, + "c3": 3, + "c4": 2, + "c2": 3, + "c5": 3, + "c6": 3, + "c7": 3, + }, + }) + + assert.Equal(t, 10, len(expected.ChannelAssignment)) + assert.Equal(t, int64(2), expected.ChannelAssignment["c1"].ServerID) + assert.Equal(t, int64(2), expected.ChannelAssignment["c4"].ServerID) + counts = countByServerID(expected) + assert.Equal(t, 3, len(counts)) + for _, count := range counts { + assert.GreaterOrEqual(t, count, 3) + assert.LessOrEqual(t, count, 4) + } + assert.NoError(t, err) + + assert.Equal(t, "pchannel_count_fair", policy.Name()) + expected, err = policy.Balance(balancer.CurrentLayout{ + IncomingChannels: []string{ + "c10", + }, + AllNodesInfo: map[int64]types.StreamingNodeInfo{ + 1: {ServerID: 1}, + 2: {ServerID: 2}, + 3: {ServerID: 3}, + }, + AssignedChannels: map[int64][]types.PChannelInfo{ + 1: { + {Name: "c1"}, + {Name: "c2"}, + {Name: "c3"}, + }, + 2: { + {Name: "c4"}, + {Name: "c5"}, + {Name: "c6"}, + }, + 3: { + {Name: "c7"}, + {Name: "c8"}, + {Name: "c9"}, + }, + }, + ChannelsToNodes: map[string]int64{ + "c1": 1, + "c2": 1, + "c3": 1, + "c4": 2, + "c5": 2, + "c6": 2, + "c7": 3, + "c8": 3, + "c9": 3, + }, + }) + + assert.Equal(t, 10, len(expected.ChannelAssignment)) + assert.Equal(t, int64(1), expected.ChannelAssignment["c1"].ServerID) + assert.Equal(t, int64(1), expected.ChannelAssignment["c2"].ServerID) + assert.Equal(t, int64(1), expected.ChannelAssignment["c3"].ServerID) + assert.Equal(t, int64(2), expected.ChannelAssignment["c4"].ServerID) + assert.Equal(t, int64(2), expected.ChannelAssignment["c5"].ServerID) + assert.Equal(t, int64(2), expected.ChannelAssignment["c6"].ServerID) + assert.Equal(t, int64(3), expected.ChannelAssignment["c7"].ServerID) + assert.Equal(t, int64(3), expected.ChannelAssignment["c8"].ServerID) + assert.Equal(t, int64(3), expected.ChannelAssignment["c9"].ServerID) + counts = countByServerID(expected) + assert.Equal(t, 3, len(counts)) + for _, count := range counts { + assert.GreaterOrEqual(t, count, 3) + assert.LessOrEqual(t, count, 4) + } + assert.NoError(t, err) + + _, err = policy.Balance(balancer.CurrentLayout{}) + assert.Error(t, err) +} + +func countByServerID(expected balancer.ExpectedLayout) map[int64]int { + counts := make(map[int64]int) + for _, node := range expected.ChannelAssignment { + counts[node.ServerID]++ + } + return counts +} diff --git a/internal/streamingcoord/server/balancer/policy_registry.go b/internal/streamingcoord/server/balancer/policy_registry.go new file mode 100644 index 0000000000000..a198627cc2c59 --- /dev/null +++ b/internal/streamingcoord/server/balancer/policy_registry.go @@ -0,0 +1,65 @@ +package balancer + +import ( + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// policies is a map of registered balancer policies. +var policies typeutil.ConcurrentMap[string, Policy] + +// CurrentLayout is the full topology of streaming node and pChannel. +type CurrentLayout struct { + IncomingChannels []string // IncomingChannels is the channels that are waiting for assignment (not assigned in AllNodesInfo). + AllNodesInfo map[int64]types.StreamingNodeInfo // AllNodesInfo is the full information of all available streaming nodes and related pchannels (contain the node not assign anything on it). + AssignedChannels map[int64][]types.PChannelInfo // AssignedChannels maps the node id to assigned channels. + ChannelsToNodes map[string]int64 // ChannelsToNodes maps assigned channel name to node id. +} + +// TotalChannels returns the total number of channels in the layout. +func (layout *CurrentLayout) TotalChannels() int { + return len(layout.IncomingChannels) + len(layout.ChannelsToNodes) +} + +// TotalNodes returns the total number of nodes in the layout. +func (layout *CurrentLayout) TotalNodes() int { + return len(layout.AllNodesInfo) +} + +// ExpectedLayout is the expected layout of streaming node and pChannel. +type ExpectedLayout struct { + ChannelAssignment map[string]types.StreamingNodeInfo // ChannelAssignment is the assignment of channel to node. +} + +// Policy is a interface to define the policy of rebalance. +type Policy interface { + // Name is the name of the policy. + Name() string + + // Balance is a function to balance the load of streaming node. + // 1. all channel should be assigned. + // 2. incoming layout should not be changed. + // 3. return a expected layout. + // 4. otherwise, error must be returned. + // return a map of channel to a list of balance operation. + // All balance operation in a list will be executed in order. + // different channel's balance operation can be executed concurrently. + Balance(currentLayout CurrentLayout) (expectedLayout ExpectedLayout, err error) +} + +// RegisterPolicy registers balancer policy. +func RegisterPolicy(p Policy) { + _, loaded := policies.GetOrInsert(p.Name(), p) + if loaded { + panic("policy already registered: " + p.Name()) + } +} + +// mustGetPolicy returns the walimpls builder by name. +func mustGetPolicy(name string) Policy { + b, ok := policies.Get(name) + if !ok { + panic("policy not found: " + name) + } + return b +} diff --git a/internal/streamingcoord/server/balancer/request.go b/internal/streamingcoord/server/balancer/request.go new file mode 100644 index 0000000000000..2693cc40d9cc2 --- /dev/null +++ b/internal/streamingcoord/server/balancer/request.go @@ -0,0 +1,42 @@ +package balancer + +import ( + "context" + + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/syncutil" +) + +// request is a operation request. +type request struct { + ctx context.Context + apply requestApply + future *syncutil.Future[error] +} + +// requestApply is a request operation to be executed. +type requestApply func(impl *balancerImpl) + +// newOpMarkAsUnavailable is a operation to mark some channels as unavailable. +func newOpMarkAsUnavailable(ctx context.Context, pChannels []types.PChannelInfo) *request { + future := syncutil.NewFuture[error]() + return &request{ + ctx: ctx, + apply: func(impl *balancerImpl) { + future.Set(impl.channelMetaManager.MarkAsUnavailable(ctx, pChannels)) + }, + future: future, + } +} + +// newOpTrigger is a operation to trigger a re-balance operation. +func newOpTrigger(ctx context.Context) *request { + future := syncutil.NewFuture[error]() + return &request{ + ctx: ctx, + apply: func(impl *balancerImpl) { + future.Set(nil) + }, + future: future, + } +} diff --git a/internal/streamingcoord/server/resource/resource.go b/internal/streamingcoord/server/resource/resource.go new file mode 100644 index 0000000000000..6dcf4e5c44a20 --- /dev/null +++ b/internal/streamingcoord/server/resource/resource.go @@ -0,0 +1,66 @@ +package resource + +import ( + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/milvus-io/milvus/internal/metastore" +) + +var r *resourceImpl // singleton resource instance + +// optResourceInit is the option to initialize the resource. +type optResourceInit func(r *resourceImpl) + +// OptETCD provides the etcd client to the resource. +func OptETCD(etcd *clientv3.Client) optResourceInit { + return func(r *resourceImpl) { + r.etcdClient = etcd + } +} + +// OptStreamingCatalog provides streaming catalog to the resource. +func OptStreamingCatalog(catalog metastore.StreamingCoordCataLog) optResourceInit { + return func(r *resourceImpl) { + r.streamingCatalog = catalog + } +} + +// Init initializes the singleton of resources. +// Should be call when streaming node startup. +func Init(opts ...optResourceInit) { + r = &resourceImpl{} + for _, opt := range opts { + opt(r) + } + assertNotNil(r.ETCD()) + assertNotNil(r.StreamingCatalog()) +} + +// Resource access the underlying singleton of resources. +func Resource() *resourceImpl { + return r +} + +// resourceImpl is a basic resource dependency for streamingnode server. +// All utility on it is concurrent-safe and singleton. +type resourceImpl struct { + etcdClient *clientv3.Client + streamingCatalog metastore.StreamingCoordCataLog +} + +// StreamingCatalog returns the StreamingCatalog client. +func (r *resourceImpl) StreamingCatalog() metastore.StreamingCoordCataLog { + return r.streamingCatalog +} + +// ETCD returns the etcd client. +func (r *resourceImpl) ETCD() *clientv3.Client { + return r.etcdClient +} + +// assertNotNil panics if the resource is nil. +func assertNotNil(v interface{}) { + if v == nil { + panic("nil resource") + } +} diff --git a/internal/streamingcoord/server/resource/resource_test.go b/internal/streamingcoord/server/resource/resource_test.go new file mode 100644 index 0000000000000..55a5879a08aff --- /dev/null +++ b/internal/streamingcoord/server/resource/resource_test.go @@ -0,0 +1,32 @@ +package resource + +import ( + "testing" + + "github.com/stretchr/testify/assert" + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/milvus-io/milvus/internal/mocks/mock_metastore" +) + +func TestInit(t *testing.T) { + assert.Panics(t, func() { + Init() + }) + assert.Panics(t, func() { + Init(OptETCD(&clientv3.Client{})) + }) + assert.Panics(t, func() { + Init(OptETCD(&clientv3.Client{})) + }) + Init(OptETCD(&clientv3.Client{}), OptStreamingCatalog( + mock_metastore.NewMockStreamingCoordCataLog(t), + )) + + assert.NotNil(t, Resource().StreamingCatalog()) + assert.NotNil(t, Resource().ETCD()) +} + +func TestInitForTest(t *testing.T) { + InitForTest() +} diff --git a/internal/streamingcoord/server/resource/test_utility.go b/internal/streamingcoord/server/resource/test_utility.go new file mode 100644 index 0000000000000..ec9833ff793bf --- /dev/null +++ b/internal/streamingcoord/server/resource/test_utility.go @@ -0,0 +1,12 @@ +//go:build test +// +build test + +package resource + +// InitForTest initializes the singleton of resources for test. +func InitForTest(opts ...optResourceInit) { + r = &resourceImpl{} + for _, opt := range opts { + opt(r) + } +} diff --git a/internal/streamingcoord/server/service/assignment.go b/internal/streamingcoord/server/service/assignment.go new file mode 100644 index 0000000000000..09a76d7cf8fc2 --- /dev/null +++ b/internal/streamingcoord/server/service/assignment.go @@ -0,0 +1,37 @@ +package service + +import ( + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/internal/streamingcoord/server/service/discover" + "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +var _ streamingpb.StreamingCoordAssignmentServiceServer = (*assignmentServiceImpl)(nil) + +// NewAssignmentService returns a new assignment service. +func NewAssignmentService( + balancer balancer.Balancer, +) streamingpb.StreamingCoordAssignmentServiceServer { + return &assignmentServiceImpl{ + balancer: balancer, + } +} + +type AssignmentService interface { + streamingpb.StreamingCoordAssignmentServiceServer +} + +// assignmentServiceImpl is the implementation of the assignment service. +type assignmentServiceImpl struct { + balancer balancer.Balancer +} + +// AssignmentDiscover watches the state of all log nodes. +func (s *assignmentServiceImpl) AssignmentDiscover(server streamingpb.StreamingCoordAssignmentService_AssignmentDiscoverServer) error { + metrics.StreamingCoordAssignmentListenerTotal.WithLabelValues(paramtable.GetStringNodeID()).Inc() + defer metrics.StreamingCoordAssignmentListenerTotal.WithLabelValues(paramtable.GetStringNodeID()).Dec() + + return discover.NewAssignmentDiscoverServer(s.balancer, server).Execute() +} diff --git a/internal/streamingcoord/server/service/discover/discover_grpc_server_helper.go b/internal/streamingcoord/server/service/discover/discover_grpc_server_helper.go new file mode 100644 index 0000000000000..02270755a5d06 --- /dev/null +++ b/internal/streamingcoord/server/service/discover/discover_grpc_server_helper.go @@ -0,0 +1,51 @@ +package discover + +import ( + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/util/streamingutil/typeconverter" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// discoverGrpcServerHelper is a wrapped discover server of log messages. +type discoverGrpcServerHelper struct { + streamingpb.StreamingCoordAssignmentService_AssignmentDiscoverServer +} + +// SendFullAssignment sends the full assignment to client. +func (h *discoverGrpcServerHelper) SendFullAssignment(v typeutil.VersionInt64Pair, relations []types.PChannelInfoAssigned) error { + assignmentsMap := make(map[int64]*streamingpb.StreamingNodeAssignment) + for _, relation := range relations { + if assignmentsMap[relation.Node.ServerID] == nil { + assignmentsMap[relation.Node.ServerID] = &streamingpb.StreamingNodeAssignment{ + Node: typeconverter.NewProtoFromStreamingNodeInfo(relation.Node), + Channels: make([]*streamingpb.PChannelInfo, 0), + } + } + assignmentsMap[relation.Node.ServerID].Channels = append( + assignmentsMap[relation.Node.ServerID].Channels, typeconverter.NewProtoFromPChannelInfo(relation.Channel)) + } + + assignments := make([]*streamingpb.StreamingNodeAssignment, 0, len(assignmentsMap)) + for _, node := range assignmentsMap { + assignments = append(assignments, node) + } + return h.Send(&streamingpb.AssignmentDiscoverResponse{ + Response: &streamingpb.AssignmentDiscoverResponse_FullAssignment{ + FullAssignment: &streamingpb.FullStreamingNodeAssignmentWithVersion{ + Version: &streamingpb.VersionPair{ + Global: v.Global, + Local: v.Local, + }, + Assignments: assignments, + }, + }, + }) +} + +// SendCloseResponse sends the close response to client. +func (h *discoverGrpcServerHelper) SendCloseResponse() error { + return h.Send(&streamingpb.AssignmentDiscoverResponse{ + Response: &streamingpb.AssignmentDiscoverResponse_Close{}, + }) +} diff --git a/internal/streamingcoord/server/service/discover/discover_server.go b/internal/streamingcoord/server/service/discover/discover_server.go new file mode 100644 index 0000000000000..ff08092f3909d --- /dev/null +++ b/internal/streamingcoord/server/service/discover/discover_server.go @@ -0,0 +1,98 @@ +package discover + +import ( + "context" + "io" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/internal/util/streamingutil/typeconverter" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +var errClosedByUser = errors.New("closed by user") + +func NewAssignmentDiscoverServer( + balancer balancer.Balancer, + streamServer streamingpb.StreamingCoordAssignmentService_AssignmentDiscoverServer, +) *AssignmentDiscoverServer { + ctx, cancel := context.WithCancelCause(streamServer.Context()) + return &AssignmentDiscoverServer{ + ctx: ctx, + cancel: cancel, + balancer: balancer, + streamServer: discoverGrpcServerHelper{ + streamServer, + }, + logger: log.With(), + } +} + +type AssignmentDiscoverServer struct { + ctx context.Context + cancel context.CancelCauseFunc + balancer balancer.Balancer + streamServer discoverGrpcServerHelper + logger *log.MLogger +} + +func (s *AssignmentDiscoverServer) Execute() error { + // Start a recv arm to handle the control message from client. + go func() { + // recv loop will be blocked until the stream is closed. + // 1. close by client. + // 2. close by server context cancel by return of outside Execute. + _ = s.recvLoop() + }() + + // Start a send loop on current main goroutine. + // the loop will be blocked until: + // 1. the stream is broken. + // 2. recv arm recv closed and all response is sent. + return s.sendLoop() +} + +// recvLoop receives the message from client. +func (s *AssignmentDiscoverServer) recvLoop() (err error) { + defer func() { + if err != nil { + s.cancel(err) + s.logger.Warn("recv arm of stream closed by unexpected error", zap.Error(err)) + return + } + s.cancel(errClosedByUser) + s.logger.Info("recv arm of stream closed") + }() + + for { + req, err := s.streamServer.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + switch req := req.Command.(type) { + case *streamingpb.AssignmentDiscoverRequest_ReportError: + channel := typeconverter.NewPChannelInfoFromProto(req.ReportError.GetPchannel()) + // mark the channel as unavailable and trigger a recover right away. + s.balancer.MarkAsUnavailable(s.ctx, []types.PChannelInfo{channel}) + case *streamingpb.AssignmentDiscoverRequest_Close: + default: + s.logger.Warn("unknown command type", zap.Any("command", req)) + } + } +} + +// sendLoop sends the message to client. +func (s *AssignmentDiscoverServer) sendLoop() error { + err := s.balancer.WatchBalanceResult(s.ctx, s.streamServer.SendFullAssignment) + if errors.Is(err, errClosedByUser) { + return s.streamServer.SendCloseResponse() + } + return err +} diff --git a/internal/streamingcoord/server/service/discover/discover_server_test.go b/internal/streamingcoord/server/service/discover/discover_server_test.go new file mode 100644 index 0000000000000..6f35309c51a03 --- /dev/null +++ b/internal/streamingcoord/server/service/discover/discover_server_test.go @@ -0,0 +1,82 @@ +package discover + +import ( + "context" + "io" + "testing" + + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus/internal/mocks/proto/mock_streamingpb" + "github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer" + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +func TestAssignmentDiscover(t *testing.T) { + b := mock_balancer.NewMockBalancer(t) + b.EXPECT().WatchBalanceResult(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb func(typeutil.VersionInt64Pair, []types.PChannelInfoAssigned) error) error { + versions := []typeutil.VersionInt64Pair{ + {Global: 1, Local: 2}, + {Global: 1, Local: 3}, + } + pchans := [][]types.PChannelInfoAssigned{ + { + types.PChannelInfoAssigned{ + Channel: types.PChannelInfo{Name: "pchannel", Term: 1}, + Node: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"}, + }, + }, + { + types.PChannelInfoAssigned{ + Channel: types.PChannelInfo{Name: "pchannel", Term: 1}, + Node: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"}, + }, + types.PChannelInfoAssigned{ + Channel: types.PChannelInfo{Name: "pchannel2", Term: 1}, + Node: types.StreamingNodeInfo{ServerID: 1, Address: "localhost:1"}, + }, + }, + } + for i := 0; i < len(versions); i++ { + cb(versions[i], pchans[i]) + } + <-ctx.Done() + return context.Cause(ctx) + }) + b.EXPECT().MarkAsUnavailable(mock.Anything, mock.Anything).Return(nil) + + streamServer := mock_streamingpb.NewMockStreamingCoordAssignmentService_AssignmentDiscoverServer(t) + streamServer.EXPECT().Context().Return(context.Background()) + k := 0 + reqs := []*streamingpb.AssignmentDiscoverRequest{ + { + Command: &streamingpb.AssignmentDiscoverRequest_ReportError{ + ReportError: &streamingpb.ReportAssignmentErrorRequest{ + Pchannel: &streamingpb.PChannelInfo{ + Name: "pchannel", + Term: 1, + }, + Err: &streamingpb.StreamingError{ + Code: streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_EXIST, + }, + }, + }, + }, + { + Command: &streamingpb.AssignmentDiscoverRequest_Close{}, + }, + } + streamServer.EXPECT().Recv().RunAndReturn(func() (*streamingpb.AssignmentDiscoverRequest, error) { + if k >= len(reqs) { + return nil, io.EOF + } + req := reqs[k] + k++ + return req, nil + }) + streamServer.EXPECT().Send(mock.Anything).Return(nil) + ads := NewAssignmentDiscoverServer(b, streamServer) + ads.Execute() +} diff --git a/internal/streamingnode/client/manager/manager.go b/internal/streamingnode/client/manager/manager.go new file mode 100644 index 0000000000000..5bb2f55c6b2d3 --- /dev/null +++ b/internal/streamingnode/client/manager/manager.go @@ -0,0 +1,25 @@ +package manager + +import ( + "context" + + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +type ManagerClient interface { + // WatchNodeChanged returns a channel that receive a node change. + WatchNodeChanged(ctx context.Context) <-chan map[int64]*sessionutil.SessionRaw + + // CollectStatus collects status of all wal instances in all streamingnode. + CollectAllStatus(ctx context.Context) (map[int64]types.StreamingNodeStatus, error) + + // Assign a wal instance for the channel on log node of given server id. + Assign(ctx context.Context, pchannel types.PChannelInfoAssigned) error + + // Remove the wal instance for the channel on log node of given server id. + Remove(ctx context.Context, pchannel types.PChannelInfoAssigned) error + + // Close closes the manager client. + Close() +} diff --git a/internal/streamingservice/.mockery.yaml b/internal/streamingservice/.mockery.yaml index 9c592d044b846..8628b02263300 100644 --- a/internal/streamingservice/.mockery.yaml +++ b/internal/streamingservice/.mockery.yaml @@ -1,10 +1,16 @@ quiet: False with-expecter: True filename: "mock_{{.InterfaceName}}.go" -dir: "internal/mocks/{{trimPrefix .PackagePath \"github.com/milvus-io/milvus/internal\" | dir }}/mock_{{.PackageName}}" +dir: 'internal/mocks/{{trimPrefix .PackagePath "github.com/milvus-io/milvus/internal" | dir }}/mock_{{.PackageName}}' mockname: "Mock{{.InterfaceName}}" outpkg: "mock_{{.PackageName}}" packages: + github.com/milvus-io/milvus/internal/streamingcoord/server/balancer: + interfaces: + Balancer: + github.com/milvus-io/milvus/internal/streamingnode/client/manager: + interfaces: + ManagerClient: github.com/milvus-io/milvus/internal/streamingnode/server/wal: interfaces: OpenerBuilder: @@ -23,6 +29,10 @@ packages: interfaces: StreamingNodeHandlerService_ConsumeServer: StreamingNodeHandlerService_ProduceServer: + StreamingCoordAssignmentService_AssignmentDiscoverServer: github.com/milvus-io/milvus/internal/streamingnode/server/walmanager: interfaces: Manager: + github.com/milvus-io/milvus/internal/metastore: + interfaces: + StreamingCoordCataLog: diff --git a/internal/util/streamingutil/typeconverter/streaming_node.go b/internal/util/streamingutil/typeconverter/streaming_node.go new file mode 100644 index 0000000000000..62498acbbdd6a --- /dev/null +++ b/internal/util/streamingutil/typeconverter/streaming_node.go @@ -0,0 +1,20 @@ +package typeconverter + +import ( + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +func NewStreamingNodeInfoFromProto(proto *streamingpb.StreamingNodeInfo) types.StreamingNodeInfo { + return types.StreamingNodeInfo{ + ServerID: proto.ServerId, + Address: proto.Address, + } +} + +func NewProtoFromStreamingNodeInfo(info types.StreamingNodeInfo) *streamingpb.StreamingNodeInfo { + return &streamingpb.StreamingNodeInfo{ + ServerId: info.ServerID, + Address: info.Address, + } +} diff --git a/internal/util/streamingutil/util/topic.go b/internal/util/streamingutil/util/topic.go new file mode 100644 index 0000000000000..8020e2a1bed9d --- /dev/null +++ b/internal/util/streamingutil/util/topic.go @@ -0,0 +1,30 @@ +package util + +import ( + "fmt" + + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// GetAllTopicsFromConfiguration gets all topics from configuration. +// It's a utility function to fetch all topics from configuration. +func GetAllTopicsFromConfiguration() typeutil.Set[string] { + var channels typeutil.Set[string] + if paramtable.Get().CommonCfg.PreCreatedTopicEnabled.GetAsBool() { + channels = typeutil.NewSet[string](paramtable.Get().CommonCfg.TopicNames.GetAsStrings()...) + } else { + channels = genChannelNames(paramtable.Get().CommonCfg.RootCoordDml.GetValue(), paramtable.Get().RootCoordCfg.DmlChannelNum.GetAsInt()) + } + return channels +} + +// genChannelNames generates channel names with prefix and number. +func genChannelNames(prefix string, num int) typeutil.Set[string] { + results := typeutil.NewSet[string]() + for idx := 0; idx < num; idx++ { + result := fmt.Sprintf("%s_%d", prefix, idx) + results.Insert(result) + } + return results +} diff --git a/internal/util/streamingutil/util/topic_test.go b/internal/util/streamingutil/util/topic_test.go new file mode 100644 index 0000000000000..bdce25066b1f5 --- /dev/null +++ b/internal/util/streamingutil/util/topic_test.go @@ -0,0 +1,19 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +func TestGetAllTopicsFromConfiguration(t *testing.T) { + paramtable.Init() + topics := GetAllTopicsFromConfiguration() + assert.Len(t, topics, 16) + paramtable.Get().CommonCfg.PreCreatedTopicEnabled.SwapTempValue("true") + paramtable.Get().CommonCfg.TopicNames.SwapTempValue("topic1,topic2,topic3") + topics = GetAllTopicsFromConfiguration() + assert.Len(t, topics, 3) +} diff --git a/pkg/streaming/.mockery.yaml b/pkg/.mockery_pkg.yaml similarity index 90% rename from pkg/streaming/.mockery.yaml rename to pkg/.mockery_pkg.yaml index 21b67cc759011..158f9709757c0 100644 --- a/pkg/streaming/.mockery.yaml +++ b/pkg/.mockery_pkg.yaml @@ -5,6 +5,9 @@ dir: "mocks/{{trimPrefix .PackagePath \"github.com/milvus-io/milvus/pkg\" | dir mockname: "Mock{{.InterfaceName}}" outpkg: "mock_{{.PackageName}}" packages: + github.com/milvus-io/milvus/pkg/kv: + interfaces: + MetaKv: github.com/milvus-io/milvus/pkg/streaming/util/message: interfaces: MessageID: diff --git a/pkg/Makefile b/pkg/Makefile index eb3af3a7b1957..cb09dd830dfde 100644 --- a/pkg/Makefile +++ b/pkg/Makefile @@ -11,12 +11,10 @@ INSTALL_PATH := $(ROOTPATH)/bin getdeps: $(MAKE) -C $(ROOTPATH) getdeps -generate-mockery: getdeps generate-mockery-streaming +generate-mockery: getdeps + $(INSTALL_PATH)/mockery --config $(PWD)/.mockery_pkg.yaml $(INSTALL_PATH)/mockery --name=MsgStream --dir=$(PWD)/mq/msgstream --output=$(PWD)/mq/msgstream --filename=mock_msgstream.go --with-expecter --structname=MockMsgStream --outpkg=msgstream --inpackage $(INSTALL_PATH)/mockery --name=Factory --dir=$(PWD)/mq/msgstream --output=$(PWD)/mq/msgstream --filename=mock_msgstream_factory.go --with-expecter --structname=MockFactory --outpkg=msgstream --inpackage $(INSTALL_PATH)/mockery --name=Client --dir=$(PWD)/mq/msgdispatcher --output=$(PWD)/mq/msgsdispatcher --filename=mock_client.go --with-expecter --structname=MockClient --outpkg=msgdispatcher --inpackage $(INSTALL_PATH)/mockery --name=Logger --dir=$(PWD)/eventlog --output=$(PWD)/eventlog --filename=mock_logger.go --with-expecter --structname=MockLogger --outpkg=eventlog --inpackage - $(INSTALL_PATH)/mockery --name=MessageID --dir=$(PWD)/mq/msgstream/mqwrapper --output=$(PWD)/mq/msgstream/mqwrapper --filename=mock_id.go --with-expecter --structname=MockMessageID --outpkg=mqwrapper --inpackage - -generate-mockery-streaming: getdeps - $(INSTALL_PATH)/mockery --config $(PWD)/streaming/.mockery.yaml + $(INSTALL_PATH)/mockery --name=MessageID --dir=$(PWD)/mq/msgstream/mqwrapper --output=$(PWD)/mq/msgstream/mqwrapper --filename=mock_id.go --with-expecter --structname=MockMessageID --outpkg=mqwrapper --inpackage \ No newline at end of file diff --git a/pkg/metrics/streaming_service_metrics.go b/pkg/metrics/streaming_service_metrics.go index 9e2899d7f8ca4..2275f9a142fe3 100644 --- a/pkg/metrics/streaming_service_metrics.go +++ b/pkg/metrics/streaming_service_metrics.go @@ -65,10 +65,10 @@ var ( Help: "Total of assignment listener", }) - StreamingCoordAssignmentInfo = newStreamingCoordGaugeVec(prometheus.GaugeOpts{ + StreamingCoordAssignmentVersion = newStreamingCoordGaugeVec(prometheus.GaugeOpts{ Name: "assignment_info", Help: "Info of assignment", - }, "global_version", "local_version") + }) // StreamingNode metrics StreamingNodeWALTotal = newStreamingNodeGaugeVec(prometheus.GaugeOpts{ @@ -119,7 +119,7 @@ func RegisterStreamingServiceClient(registry *prometheus.Registry) { func RegisterStreamingCoord(registry *prometheus.Registry) { registry.MustRegister(StreamingCoordPChannelTotal) registry.MustRegister(StreamingCoordAssignmentListenerTotal) - registry.MustRegister(StreamingCoordAssignmentInfo) + registry.MustRegister(StreamingCoordAssignmentVersion) } // RegisterStreamingNode registers log service metrics diff --git a/pkg/mocks/mock_kv/mock_MetaKv.go b/pkg/mocks/mock_kv/mock_MetaKv.go new file mode 100644 index 0000000000000..5744be0fc7048 --- /dev/null +++ b/pkg/mocks/mock_kv/mock_MetaKv.go @@ -0,0 +1,807 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mock_kv + +import ( + predicates "github.com/milvus-io/milvus/pkg/kv/predicates" + mock "github.com/stretchr/testify/mock" +) + +// MockMetaKv is an autogenerated mock type for the MetaKv type +type MockMetaKv struct { + mock.Mock +} + +type MockMetaKv_Expecter struct { + mock *mock.Mock +} + +func (_m *MockMetaKv) EXPECT() *MockMetaKv_Expecter { + return &MockMetaKv_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with given fields: +func (_m *MockMetaKv) Close() { + _m.Called() +} + +// MockMetaKv_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockMetaKv_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockMetaKv_Expecter) Close() *MockMetaKv_Close_Call { + return &MockMetaKv_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockMetaKv_Close_Call) Run(run func()) *MockMetaKv_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMetaKv_Close_Call) Return() *MockMetaKv_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *MockMetaKv_Close_Call) RunAndReturn(run func()) *MockMetaKv_Close_Call { + _c.Call.Return(run) + return _c +} + +// CompareVersionAndSwap provides a mock function with given fields: key, version, target +func (_m *MockMetaKv) CompareVersionAndSwap(key string, version int64, target string) (bool, error) { + ret := _m.Called(key, version, target) + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string, int64, string) (bool, error)); ok { + return rf(key, version, target) + } + if rf, ok := ret.Get(0).(func(string, int64, string) bool); ok { + r0 = rf(key, version, target) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string, int64, string) error); ok { + r1 = rf(key, version, target) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMetaKv_CompareVersionAndSwap_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CompareVersionAndSwap' +type MockMetaKv_CompareVersionAndSwap_Call struct { + *mock.Call +} + +// CompareVersionAndSwap is a helper method to define mock.On call +// - key string +// - version int64 +// - target string +func (_e *MockMetaKv_Expecter) CompareVersionAndSwap(key interface{}, version interface{}, target interface{}) *MockMetaKv_CompareVersionAndSwap_Call { + return &MockMetaKv_CompareVersionAndSwap_Call{Call: _e.mock.On("CompareVersionAndSwap", key, version, target)} +} + +func (_c *MockMetaKv_CompareVersionAndSwap_Call) Run(run func(key string, version int64, target string)) *MockMetaKv_CompareVersionAndSwap_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(int64), args[2].(string)) + }) + return _c +} + +func (_c *MockMetaKv_CompareVersionAndSwap_Call) Return(_a0 bool, _a1 error) *MockMetaKv_CompareVersionAndSwap_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMetaKv_CompareVersionAndSwap_Call) RunAndReturn(run func(string, int64, string) (bool, error)) *MockMetaKv_CompareVersionAndSwap_Call { + _c.Call.Return(run) + return _c +} + +// GetPath provides a mock function with given fields: key +func (_m *MockMetaKv) GetPath(key string) string { + ret := _m.Called(key) + + var r0 string + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(key) + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockMetaKv_GetPath_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPath' +type MockMetaKv_GetPath_Call struct { + *mock.Call +} + +// GetPath is a helper method to define mock.On call +// - key string +func (_e *MockMetaKv_Expecter) GetPath(key interface{}) *MockMetaKv_GetPath_Call { + return &MockMetaKv_GetPath_Call{Call: _e.mock.On("GetPath", key)} +} + +func (_c *MockMetaKv_GetPath_Call) Run(run func(key string)) *MockMetaKv_GetPath_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockMetaKv_GetPath_Call) Return(_a0 string) *MockMetaKv_GetPath_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMetaKv_GetPath_Call) RunAndReturn(run func(string) string) *MockMetaKv_GetPath_Call { + _c.Call.Return(run) + return _c +} + +// Has provides a mock function with given fields: key +func (_m *MockMetaKv) Has(key string) (bool, error) { + ret := _m.Called(key) + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string) (bool, error)); ok { + return rf(key) + } + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(key) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(key) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMetaKv_Has_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Has' +type MockMetaKv_Has_Call struct { + *mock.Call +} + +// Has is a helper method to define mock.On call +// - key string +func (_e *MockMetaKv_Expecter) Has(key interface{}) *MockMetaKv_Has_Call { + return &MockMetaKv_Has_Call{Call: _e.mock.On("Has", key)} +} + +func (_c *MockMetaKv_Has_Call) Run(run func(key string)) *MockMetaKv_Has_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockMetaKv_Has_Call) Return(_a0 bool, _a1 error) *MockMetaKv_Has_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMetaKv_Has_Call) RunAndReturn(run func(string) (bool, error)) *MockMetaKv_Has_Call { + _c.Call.Return(run) + return _c +} + +// HasPrefix provides a mock function with given fields: prefix +func (_m *MockMetaKv) HasPrefix(prefix string) (bool, error) { + ret := _m.Called(prefix) + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string) (bool, error)); ok { + return rf(prefix) + } + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(prefix) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(prefix) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMetaKv_HasPrefix_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HasPrefix' +type MockMetaKv_HasPrefix_Call struct { + *mock.Call +} + +// HasPrefix is a helper method to define mock.On call +// - prefix string +func (_e *MockMetaKv_Expecter) HasPrefix(prefix interface{}) *MockMetaKv_HasPrefix_Call { + return &MockMetaKv_HasPrefix_Call{Call: _e.mock.On("HasPrefix", prefix)} +} + +func (_c *MockMetaKv_HasPrefix_Call) Run(run func(prefix string)) *MockMetaKv_HasPrefix_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockMetaKv_HasPrefix_Call) Return(_a0 bool, _a1 error) *MockMetaKv_HasPrefix_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMetaKv_HasPrefix_Call) RunAndReturn(run func(string) (bool, error)) *MockMetaKv_HasPrefix_Call { + _c.Call.Return(run) + return _c +} + +// Load provides a mock function with given fields: key +func (_m *MockMetaKv) Load(key string) (string, error) { + ret := _m.Called(key) + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(string) (string, error)); ok { + return rf(key) + } + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(key) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(key) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMetaKv_Load_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Load' +type MockMetaKv_Load_Call struct { + *mock.Call +} + +// Load is a helper method to define mock.On call +// - key string +func (_e *MockMetaKv_Expecter) Load(key interface{}) *MockMetaKv_Load_Call { + return &MockMetaKv_Load_Call{Call: _e.mock.On("Load", key)} +} + +func (_c *MockMetaKv_Load_Call) Run(run func(key string)) *MockMetaKv_Load_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockMetaKv_Load_Call) Return(_a0 string, _a1 error) *MockMetaKv_Load_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMetaKv_Load_Call) RunAndReturn(run func(string) (string, error)) *MockMetaKv_Load_Call { + _c.Call.Return(run) + return _c +} + +// LoadWithPrefix provides a mock function with given fields: key +func (_m *MockMetaKv) LoadWithPrefix(key string) ([]string, []string, error) { + ret := _m.Called(key) + + var r0 []string + var r1 []string + var r2 error + if rf, ok := ret.Get(0).(func(string) ([]string, []string, error)); ok { + return rf(key) + } + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(key) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(string) []string); ok { + r1 = rf(key) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([]string) + } + } + + if rf, ok := ret.Get(2).(func(string) error); ok { + r2 = rf(key) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockMetaKv_LoadWithPrefix_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadWithPrefix' +type MockMetaKv_LoadWithPrefix_Call struct { + *mock.Call +} + +// LoadWithPrefix is a helper method to define mock.On call +// - key string +func (_e *MockMetaKv_Expecter) LoadWithPrefix(key interface{}) *MockMetaKv_LoadWithPrefix_Call { + return &MockMetaKv_LoadWithPrefix_Call{Call: _e.mock.On("LoadWithPrefix", key)} +} + +func (_c *MockMetaKv_LoadWithPrefix_Call) Run(run func(key string)) *MockMetaKv_LoadWithPrefix_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockMetaKv_LoadWithPrefix_Call) Return(_a0 []string, _a1 []string, _a2 error) *MockMetaKv_LoadWithPrefix_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockMetaKv_LoadWithPrefix_Call) RunAndReturn(run func(string) ([]string, []string, error)) *MockMetaKv_LoadWithPrefix_Call { + _c.Call.Return(run) + return _c +} + +// MultiLoad provides a mock function with given fields: keys +func (_m *MockMetaKv) MultiLoad(keys []string) ([]string, error) { + ret := _m.Called(keys) + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func([]string) ([]string, error)); ok { + return rf(keys) + } + if rf, ok := ret.Get(0).(func([]string) []string); ok { + r0 = rf(keys) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func([]string) error); ok { + r1 = rf(keys) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMetaKv_MultiLoad_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MultiLoad' +type MockMetaKv_MultiLoad_Call struct { + *mock.Call +} + +// MultiLoad is a helper method to define mock.On call +// - keys []string +func (_e *MockMetaKv_Expecter) MultiLoad(keys interface{}) *MockMetaKv_MultiLoad_Call { + return &MockMetaKv_MultiLoad_Call{Call: _e.mock.On("MultiLoad", keys)} +} + +func (_c *MockMetaKv_MultiLoad_Call) Run(run func(keys []string)) *MockMetaKv_MultiLoad_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]string)) + }) + return _c +} + +func (_c *MockMetaKv_MultiLoad_Call) Return(_a0 []string, _a1 error) *MockMetaKv_MultiLoad_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMetaKv_MultiLoad_Call) RunAndReturn(run func([]string) ([]string, error)) *MockMetaKv_MultiLoad_Call { + _c.Call.Return(run) + return _c +} + +// MultiRemove provides a mock function with given fields: keys +func (_m *MockMetaKv) MultiRemove(keys []string) error { + ret := _m.Called(keys) + + var r0 error + if rf, ok := ret.Get(0).(func([]string) error); ok { + r0 = rf(keys) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMetaKv_MultiRemove_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MultiRemove' +type MockMetaKv_MultiRemove_Call struct { + *mock.Call +} + +// MultiRemove is a helper method to define mock.On call +// - keys []string +func (_e *MockMetaKv_Expecter) MultiRemove(keys interface{}) *MockMetaKv_MultiRemove_Call { + return &MockMetaKv_MultiRemove_Call{Call: _e.mock.On("MultiRemove", keys)} +} + +func (_c *MockMetaKv_MultiRemove_Call) Run(run func(keys []string)) *MockMetaKv_MultiRemove_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]string)) + }) + return _c +} + +func (_c *MockMetaKv_MultiRemove_Call) Return(_a0 error) *MockMetaKv_MultiRemove_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMetaKv_MultiRemove_Call) RunAndReturn(run func([]string) error) *MockMetaKv_MultiRemove_Call { + _c.Call.Return(run) + return _c +} + +// MultiSave provides a mock function with given fields: kvs +func (_m *MockMetaKv) MultiSave(kvs map[string]string) error { + ret := _m.Called(kvs) + + var r0 error + if rf, ok := ret.Get(0).(func(map[string]string) error); ok { + r0 = rf(kvs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMetaKv_MultiSave_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MultiSave' +type MockMetaKv_MultiSave_Call struct { + *mock.Call +} + +// MultiSave is a helper method to define mock.On call +// - kvs map[string]string +func (_e *MockMetaKv_Expecter) MultiSave(kvs interface{}) *MockMetaKv_MultiSave_Call { + return &MockMetaKv_MultiSave_Call{Call: _e.mock.On("MultiSave", kvs)} +} + +func (_c *MockMetaKv_MultiSave_Call) Run(run func(kvs map[string]string)) *MockMetaKv_MultiSave_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(map[string]string)) + }) + return _c +} + +func (_c *MockMetaKv_MultiSave_Call) Return(_a0 error) *MockMetaKv_MultiSave_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMetaKv_MultiSave_Call) RunAndReturn(run func(map[string]string) error) *MockMetaKv_MultiSave_Call { + _c.Call.Return(run) + return _c +} + +// MultiSaveAndRemove provides a mock function with given fields: saves, removals, preds +func (_m *MockMetaKv) MultiSaveAndRemove(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + _va := make([]interface{}, len(preds)) + for _i := range preds { + _va[_i] = preds[_i] + } + var _ca []interface{} + _ca = append(_ca, saves, removals) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(map[string]string, []string, ...predicates.Predicate) error); ok { + r0 = rf(saves, removals, preds...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMetaKv_MultiSaveAndRemove_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MultiSaveAndRemove' +type MockMetaKv_MultiSaveAndRemove_Call struct { + *mock.Call +} + +// MultiSaveAndRemove is a helper method to define mock.On call +// - saves map[string]string +// - removals []string +// - preds ...predicates.Predicate +func (_e *MockMetaKv_Expecter) MultiSaveAndRemove(saves interface{}, removals interface{}, preds ...interface{}) *MockMetaKv_MultiSaveAndRemove_Call { + return &MockMetaKv_MultiSaveAndRemove_Call{Call: _e.mock.On("MultiSaveAndRemove", + append([]interface{}{saves, removals}, preds...)...)} +} + +func (_c *MockMetaKv_MultiSaveAndRemove_Call) Run(run func(saves map[string]string, removals []string, preds ...predicates.Predicate)) *MockMetaKv_MultiSaveAndRemove_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]predicates.Predicate, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(predicates.Predicate) + } + } + run(args[0].(map[string]string), args[1].([]string), variadicArgs...) + }) + return _c +} + +func (_c *MockMetaKv_MultiSaveAndRemove_Call) Return(_a0 error) *MockMetaKv_MultiSaveAndRemove_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMetaKv_MultiSaveAndRemove_Call) RunAndReturn(run func(map[string]string, []string, ...predicates.Predicate) error) *MockMetaKv_MultiSaveAndRemove_Call { + _c.Call.Return(run) + return _c +} + +// MultiSaveAndRemoveWithPrefix provides a mock function with given fields: saves, removals, preds +func (_m *MockMetaKv) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + _va := make([]interface{}, len(preds)) + for _i := range preds { + _va[_i] = preds[_i] + } + var _ca []interface{} + _ca = append(_ca, saves, removals) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(map[string]string, []string, ...predicates.Predicate) error); ok { + r0 = rf(saves, removals, preds...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMetaKv_MultiSaveAndRemoveWithPrefix_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MultiSaveAndRemoveWithPrefix' +type MockMetaKv_MultiSaveAndRemoveWithPrefix_Call struct { + *mock.Call +} + +// MultiSaveAndRemoveWithPrefix is a helper method to define mock.On call +// - saves map[string]string +// - removals []string +// - preds ...predicates.Predicate +func (_e *MockMetaKv_Expecter) MultiSaveAndRemoveWithPrefix(saves interface{}, removals interface{}, preds ...interface{}) *MockMetaKv_MultiSaveAndRemoveWithPrefix_Call { + return &MockMetaKv_MultiSaveAndRemoveWithPrefix_Call{Call: _e.mock.On("MultiSaveAndRemoveWithPrefix", + append([]interface{}{saves, removals}, preds...)...)} +} + +func (_c *MockMetaKv_MultiSaveAndRemoveWithPrefix_Call) Run(run func(saves map[string]string, removals []string, preds ...predicates.Predicate)) *MockMetaKv_MultiSaveAndRemoveWithPrefix_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]predicates.Predicate, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(predicates.Predicate) + } + } + run(args[0].(map[string]string), args[1].([]string), variadicArgs...) + }) + return _c +} + +func (_c *MockMetaKv_MultiSaveAndRemoveWithPrefix_Call) Return(_a0 error) *MockMetaKv_MultiSaveAndRemoveWithPrefix_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMetaKv_MultiSaveAndRemoveWithPrefix_Call) RunAndReturn(run func(map[string]string, []string, ...predicates.Predicate) error) *MockMetaKv_MultiSaveAndRemoveWithPrefix_Call { + _c.Call.Return(run) + return _c +} + +// Remove provides a mock function with given fields: key +func (_m *MockMetaKv) Remove(key string) error { + ret := _m.Called(key) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(key) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMetaKv_Remove_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Remove' +type MockMetaKv_Remove_Call struct { + *mock.Call +} + +// Remove is a helper method to define mock.On call +// - key string +func (_e *MockMetaKv_Expecter) Remove(key interface{}) *MockMetaKv_Remove_Call { + return &MockMetaKv_Remove_Call{Call: _e.mock.On("Remove", key)} +} + +func (_c *MockMetaKv_Remove_Call) Run(run func(key string)) *MockMetaKv_Remove_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockMetaKv_Remove_Call) Return(_a0 error) *MockMetaKv_Remove_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMetaKv_Remove_Call) RunAndReturn(run func(string) error) *MockMetaKv_Remove_Call { + _c.Call.Return(run) + return _c +} + +// RemoveWithPrefix provides a mock function with given fields: key +func (_m *MockMetaKv) RemoveWithPrefix(key string) error { + ret := _m.Called(key) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(key) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMetaKv_RemoveWithPrefix_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveWithPrefix' +type MockMetaKv_RemoveWithPrefix_Call struct { + *mock.Call +} + +// RemoveWithPrefix is a helper method to define mock.On call +// - key string +func (_e *MockMetaKv_Expecter) RemoveWithPrefix(key interface{}) *MockMetaKv_RemoveWithPrefix_Call { + return &MockMetaKv_RemoveWithPrefix_Call{Call: _e.mock.On("RemoveWithPrefix", key)} +} + +func (_c *MockMetaKv_RemoveWithPrefix_Call) Run(run func(key string)) *MockMetaKv_RemoveWithPrefix_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockMetaKv_RemoveWithPrefix_Call) Return(_a0 error) *MockMetaKv_RemoveWithPrefix_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMetaKv_RemoveWithPrefix_Call) RunAndReturn(run func(string) error) *MockMetaKv_RemoveWithPrefix_Call { + _c.Call.Return(run) + return _c +} + +// Save provides a mock function with given fields: key, value +func (_m *MockMetaKv) Save(key string, value string) error { + ret := _m.Called(key, value) + + var r0 error + if rf, ok := ret.Get(0).(func(string, string) error); ok { + r0 = rf(key, value) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMetaKv_Save_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Save' +type MockMetaKv_Save_Call struct { + *mock.Call +} + +// Save is a helper method to define mock.On call +// - key string +// - value string +func (_e *MockMetaKv_Expecter) Save(key interface{}, value interface{}) *MockMetaKv_Save_Call { + return &MockMetaKv_Save_Call{Call: _e.mock.On("Save", key, value)} +} + +func (_c *MockMetaKv_Save_Call) Run(run func(key string, value string)) *MockMetaKv_Save_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(string)) + }) + return _c +} + +func (_c *MockMetaKv_Save_Call) Return(_a0 error) *MockMetaKv_Save_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMetaKv_Save_Call) RunAndReturn(run func(string, string) error) *MockMetaKv_Save_Call { + _c.Call.Return(run) + return _c +} + +// WalkWithPrefix provides a mock function with given fields: prefix, paginationSize, fn +func (_m *MockMetaKv) WalkWithPrefix(prefix string, paginationSize int, fn func([]byte, []byte) error) error { + ret := _m.Called(prefix, paginationSize, fn) + + var r0 error + if rf, ok := ret.Get(0).(func(string, int, func([]byte, []byte) error) error); ok { + r0 = rf(prefix, paginationSize, fn) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMetaKv_WalkWithPrefix_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WalkWithPrefix' +type MockMetaKv_WalkWithPrefix_Call struct { + *mock.Call +} + +// WalkWithPrefix is a helper method to define mock.On call +// - prefix string +// - paginationSize int +// - fn func([]byte , []byte) error +func (_e *MockMetaKv_Expecter) WalkWithPrefix(prefix interface{}, paginationSize interface{}, fn interface{}) *MockMetaKv_WalkWithPrefix_Call { + return &MockMetaKv_WalkWithPrefix_Call{Call: _e.mock.On("WalkWithPrefix", prefix, paginationSize, fn)} +} + +func (_c *MockMetaKv_WalkWithPrefix_Call) Run(run func(prefix string, paginationSize int, fn func([]byte, []byte) error)) *MockMetaKv_WalkWithPrefix_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(int), args[2].(func([]byte, []byte) error)) + }) + return _c +} + +func (_c *MockMetaKv_WalkWithPrefix_Call) Return(_a0 error) *MockMetaKv_WalkWithPrefix_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMetaKv_WalkWithPrefix_Call) RunAndReturn(run func(string, int, func([]byte, []byte) error) error) *MockMetaKv_WalkWithPrefix_Call { + _c.Call.Return(run) + return _c +} + +// NewMockMetaKv creates a new instance of MockMetaKv. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockMetaKv(t interface { + mock.TestingT + Cleanup(func()) +}) *MockMetaKv { + mock := &MockMetaKv{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/mq/msgdispatcher/mock_client.go b/pkg/mq/msgdispatcher/mock_client.go index af688327ea291..883a16de98f58 100644 --- a/pkg/mq/msgdispatcher/mock_client.go +++ b/pkg/mq/msgdispatcher/mock_client.go @@ -3,14 +3,15 @@ package msgdispatcher import ( - "context" + context "context" - "github.com/milvus-io/milvus/pkg/mq/common" - "github.com/stretchr/testify/mock" + common "github.com/milvus-io/milvus/pkg/mq/common" - "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + mock "github.com/stretchr/testify/mock" - "github.com/milvus-io/milvus/pkg/mq/msgstream" + msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + + msgstream "github.com/milvus-io/milvus/pkg/mq/msgstream" ) // MockClient is an autogenerated mock type for the Client type @@ -126,7 +127,7 @@ type MockClient_Register_Call struct { // - ctx context.Context // - vchannel string // - pos *msgpb.MsgPosition -// - subPos mqwrapper.SubscriptionInitialPosition +// - subPos common.SubscriptionInitialPosition func (_e *MockClient_Expecter) Register(ctx interface{}, vchannel interface{}, pos interface{}, subPos interface{}) *MockClient_Register_Call { return &MockClient_Register_Call{Call: _e.mock.On("Register", ctx, vchannel, pos, subPos)} } diff --git a/pkg/mq/msgstream/mock_msgstream.go b/pkg/mq/msgstream/mock_msgstream.go index 8d1d2dbfad1b2..84fb32526009a 100644 --- a/pkg/mq/msgstream/mock_msgstream.go +++ b/pkg/mq/msgstream/mock_msgstream.go @@ -3,12 +3,13 @@ package msgstream import ( - "context" + context "context" - "github.com/milvus-io/milvus/pkg/mq/common" - "github.com/stretchr/testify/mock" + common "github.com/milvus-io/milvus/pkg/mq/common" - "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + mock "github.com/stretchr/testify/mock" + + msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" ) // MockMsgStream is an autogenerated mock type for the MsgStream type @@ -47,7 +48,7 @@ type MockMsgStream_AsConsumer_Call struct { // - ctx context.Context // - channels []string // - subName string -// - position mqwrapper.SubscriptionInitialPosition +// - position common.SubscriptionInitialPosition func (_e *MockMsgStream_Expecter) AsConsumer(ctx interface{}, channels interface{}, subName interface{}, position interface{}) *MockMsgStream_AsConsumer_Call { return &MockMsgStream_AsConsumer_Call{Call: _e.mock.On("AsConsumer", ctx, channels, subName, position)} } diff --git a/pkg/streaming/util/types/pchannel_info.go b/pkg/streaming/util/types/pchannel_info.go index 1da295d639710..6a4d65c26f6a7 100644 --- a/pkg/streaming/util/types/pchannel_info.go +++ b/pkg/streaming/util/types/pchannel_info.go @@ -9,3 +9,8 @@ type PChannelInfo struct { Name string // name of pchannel. Term int64 // term of pchannel. } + +type PChannelInfoAssigned struct { + Channel PChannelInfo + Node StreamingNodeInfo +} diff --git a/pkg/streaming/util/types/streaming_node.go b/pkg/streaming/util/types/streaming_node.go new file mode 100644 index 0000000000000..0b3927721ae26 --- /dev/null +++ b/pkg/streaming/util/types/streaming_node.go @@ -0,0 +1,42 @@ +package types + +import ( + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +var ( + ErrStopping = errors.New("streaming node is stopping") + ErrNotAlive = errors.New("streaming node is not alive") +) + +// VersionedStreamingNodeAssignments is the relation between server and channels with version. +type VersionedStreamingNodeAssignments struct { + Version typeutil.VersionInt64Pair + Assignments map[int64]StreamingNodeAssignment +} + +// StreamingNodeAssignment is the relation between server and channels. +type StreamingNodeAssignment struct { + NodeInfo StreamingNodeInfo + Channels []PChannelInfo +} + +// StreamingNodeInfo is the relation between server and channels. +type StreamingNodeInfo struct { + ServerID int64 + Address string +} + +// StreamingNodeStatus is the information of a streaming node. +type StreamingNodeStatus struct { + StreamingNodeInfo + // TODO: balance attributes should added here in future. + Err error +} + +// IsHealthy returns whether the streaming node is healthy. +func (n *StreamingNodeStatus) IsHealthy() bool { + return n.Err == nil +} diff --git a/pkg/streaming/util/types/streaming_node_test.go b/pkg/streaming/util/types/streaming_node_test.go new file mode 100644 index 0000000000000..579c01e596fbf --- /dev/null +++ b/pkg/streaming/util/types/streaming_node_test.go @@ -0,0 +1,15 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStreamingNodeStatus(t *testing.T) { + s := StreamingNodeStatus{Err: ErrStopping} + assert.False(t, s.IsHealthy()) + + s = StreamingNodeStatus{Err: ErrNotAlive} + assert.False(t, s.IsHealthy()) +} diff --git a/pkg/streaming/walimpls/helper/scanner_helper.go b/pkg/streaming/walimpls/helper/scanner_helper.go index 3d6ef34408e13..082d28cc84a2e 100644 --- a/pkg/streaming/walimpls/helper/scanner_helper.go +++ b/pkg/streaming/walimpls/helper/scanner_helper.go @@ -1,31 +1,28 @@ package helper -import "context" +import ( + "context" + + "github.com/milvus-io/milvus/pkg/util/syncutil" +) // NewScannerHelper creates a new ScannerHelper. func NewScannerHelper(scannerName string) *ScannerHelper { - ctx, cancel := context.WithCancel(context.Background()) return &ScannerHelper{ scannerName: scannerName, - ctx: ctx, - cancel: cancel, - finishCh: make(chan struct{}), - err: nil, + notifier: syncutil.NewAsyncTaskNotifier[error](), } } // ScannerHelper is a helper for scanner implementation. type ScannerHelper struct { scannerName string - ctx context.Context - cancel context.CancelFunc - finishCh chan struct{} - err error + notifier *syncutil.AsyncTaskNotifier[error] } // Context returns the context of the scanner, which will cancel when the scanner helper is closed. func (s *ScannerHelper) Context() context.Context { - return s.ctx + return s.notifier.Context() } // Name returns the name of the scanner. @@ -35,24 +32,21 @@ func (s *ScannerHelper) Name() string { // Error returns the error of the scanner. func (s *ScannerHelper) Error() error { - <-s.finishCh - return s.err + return s.notifier.BlockAndGetResult() } // Done returns a channel that will be closed when the scanner is finished. func (s *ScannerHelper) Done() <-chan struct{} { - return s.finishCh + return s.notifier.FinishChan() } // Close closes the scanner, block until the Finish is called. func (s *ScannerHelper) Close() error { - s.cancel() - <-s.finishCh - return s.err + s.notifier.Cancel() + return s.notifier.BlockAndGetResult() } // Finish finishes the scanner with an error. func (s *ScannerHelper) Finish(err error) { - s.err = err - close(s.finishCh) + s.notifier.Finish(err) } diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index fce8f78c8f4a2..e5c9ab174188e 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -64,16 +64,18 @@ type ComponentParam struct { GpuConfig gpuConfig TraceCfg traceConfig - RootCoordCfg rootCoordConfig - ProxyCfg proxyConfig - QueryCoordCfg queryCoordConfig - QueryNodeCfg queryNodeConfig - DataCoordCfg dataCoordConfig - DataNodeCfg dataNodeConfig - IndexNodeCfg indexNodeConfig - HTTPCfg httpConfig - LogCfg logConfig - RoleCfg roleConfig + RootCoordCfg rootCoordConfig + ProxyCfg proxyConfig + QueryCoordCfg queryCoordConfig + QueryNodeCfg queryNodeConfig + DataCoordCfg dataCoordConfig + DataNodeCfg dataNodeConfig + IndexNodeCfg indexNodeConfig + HTTPCfg httpConfig + LogCfg logConfig + RoleCfg roleConfig + StreamingCoordCfg streamingCoordConfig + StreamingNodeCfg streamingNodeConfig RootCoordGrpcServerCfg GrpcServerConfig ProxyGrpcServerCfg GrpcServerConfig @@ -125,6 +127,8 @@ func (p *ComponentParam) init(bt *BaseTable) { p.LogCfg.init(bt) p.RoleCfg.init(bt) p.GpuConfig.init(bt) + p.StreamingCoordCfg.init(bt) + p.StreamingNodeCfg.init(bt) p.RootCoordGrpcServerCfg.Init("rootCoord", bt) p.ProxyGrpcServerCfg.Init("proxy", bt) @@ -4168,6 +4172,46 @@ func (p *indexNodeConfig) init(base *BaseTable) { p.GracefulStopTimeout.Init(base.mgr) } +type streamingCoordConfig struct { + AutoBalanceTriggerInterval ParamItem `refreshable:"true"` + AutoBalanceBackoffInitialInterval ParamItem `refreshable:"true"` + AutoBalanceBackoffMultiplier ParamItem `refreshable:"true"` +} + +func (p *streamingCoordConfig) init(base *BaseTable) { + p.AutoBalanceTriggerInterval = ParamItem{ + Key: "streamingCoord.autoBalanceTriggerInterval", + Version: "2.5.0", + Doc: `The interval of balance task trigger at background, 1 min by default. +It's ok to set it into duration string, such as 30s or 1m30s, see time.ParseDuration`, + DefaultValue: "1m", + Export: true, + } + p.AutoBalanceTriggerInterval.Init(base.mgr) + p.AutoBalanceBackoffInitialInterval = ParamItem{ + Key: "streamingCoord.autoBalanceBackoffInitialInterval", + Version: "2.5.0", + Doc: `The initial interval of balance task trigger backoff, 50 ms by default. +It's ok to set it into duration string, such as 30s or 1m30s, see time.ParseDuration`, + DefaultValue: "50ms", + Export: true, + } + p.AutoBalanceBackoffInitialInterval.Init(base.mgr) + p.AutoBalanceBackoffMultiplier = ParamItem{ + Key: "streamingCoord.autoBalanceBackoffMultiplier", + Version: "2.5.0", + Doc: "The multiplier of balance task trigger backoff, 2 by default", + DefaultValue: "2", + Export: true, + } + p.AutoBalanceBackoffMultiplier.Init(base.mgr) +} + +type streamingNodeConfig struct{} + +func (p *streamingNodeConfig) init(base *BaseTable) { +} + type runtimeConfig struct { CreateTime RuntimeParamItem UpdateTime RuntimeParamItem diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go index 234b0e38558c6..dc433b898bd6a 100644 --- a/pkg/util/paramtable/component_param_test.go +++ b/pkg/util/paramtable/component_param_test.go @@ -529,6 +529,18 @@ func TestComponentParam(t *testing.T) { assert.Equal(t, 100*time.Second, Params.GracefulStopTimeout.GetAsDuration(time.Second)) }) + t.Run("test streamingCoordConfig", func(t *testing.T) { + assert.Equal(t, 1*time.Minute, params.StreamingCoordCfg.AutoBalanceTriggerInterval.GetAsDurationByParse()) + assert.Equal(t, 50*time.Millisecond, params.StreamingCoordCfg.AutoBalanceBackoffInitialInterval.GetAsDurationByParse()) + assert.Equal(t, 2.0, params.StreamingCoordCfg.AutoBalanceBackoffMultiplier.GetAsFloat()) + params.Save(params.StreamingCoordCfg.AutoBalanceTriggerInterval.Key, "50s") + params.Save(params.StreamingCoordCfg.AutoBalanceBackoffInitialInterval.Key, "50s") + params.Save(params.StreamingCoordCfg.AutoBalanceBackoffMultiplier.Key, "3.5") + assert.Equal(t, 50*time.Second, params.StreamingCoordCfg.AutoBalanceTriggerInterval.GetAsDurationByParse()) + assert.Equal(t, 50*time.Second, params.StreamingCoordCfg.AutoBalanceBackoffInitialInterval.GetAsDurationByParse()) + assert.Equal(t, 3.5, params.StreamingCoordCfg.AutoBalanceBackoffMultiplier.GetAsFloat()) + }) + t.Run("channel config priority", func(t *testing.T) { Params := ¶ms.CommonCfg params.Save(Params.RootCoordDml.Key, "dml1") diff --git a/pkg/util/paramtable/param_item.go b/pkg/util/paramtable/param_item.go index c3f00727c8d7c..7a70440a7b712 100644 --- a/pkg/util/paramtable/param_item.go +++ b/pkg/util/paramtable/param_item.go @@ -244,6 +244,18 @@ func (pi *ParamItem) GetAsRoleDetails() map[string](map[string]([](map[string]st return getAndConvert(pi.GetValue(), funcutil.JSONToRoleDetails, nil) } +func (pi *ParamItem) GetAsDurationByParse() time.Duration { + val, _ := pi.get() + durationVal, err := time.ParseDuration(val) + if err != nil { + durationVal, err = time.ParseDuration(pi.DefaultValue) + if err != nil { + panic(fmt.Sprintf("unreachable: parse duration from default value failed, %s, err: %s", pi.DefaultValue, err.Error())) + } + } + return durationVal +} + func (pi *ParamItem) GetAsSize() int64 { valueStr := strings.ToLower(pi.GetValue()) if strings.HasSuffix(valueStr, "g") || strings.HasSuffix(valueStr, "gb") { diff --git a/pkg/util/syncutil/async_task_notifier.go b/pkg/util/syncutil/async_task_notifier.go new file mode 100644 index 0000000000000..74b6a538f5d4c --- /dev/null +++ b/pkg/util/syncutil/async_task_notifier.go @@ -0,0 +1,50 @@ +package syncutil + +import "context" + +// NewAsyncTaskNotifier creates a new async task notifier. +func NewAsyncTaskNotifier[T any]() *AsyncTaskNotifier[T] { + ctx, cancel := context.WithCancel(context.Background()) + return &AsyncTaskNotifier[T]{ + ctx: ctx, + cancel: cancel, + future: NewFuture[T](), + } +} + +// AsyncTaskNotifier is a notifier for async task. +type AsyncTaskNotifier[T any] struct { + ctx context.Context + cancel context.CancelFunc + future *Future[T] +} + +// Context returns the context of the async task. +func (n *AsyncTaskNotifier[T]) Context() context.Context { + return n.ctx +} + +// Cancel cancels the async task, the async task can receive the cancel signal from Context. +func (n *AsyncTaskNotifier[T]) Cancel() { + n.cancel() +} + +// BlockAndGetResult returns the result of the async task. +func (n *AsyncTaskNotifier[T]) BlockAndGetResult() T { + return n.future.Get() +} + +// BlockUntilFinish blocks until the async task is finished. +func (n *AsyncTaskNotifier[T]) BlockUntilFinish() { + <-n.future.Done() +} + +// FinishChan returns a channel that will be closed when the async task is finished. +func (n *AsyncTaskNotifier[T]) FinishChan() <-chan struct{} { + return n.future.Done() +} + +// Finish finishes the async task with a result. +func (n *AsyncTaskNotifier[T]) Finish(result T) { + n.future.Set(result) +} diff --git a/pkg/util/syncutil/async_task_notifier_test.go b/pkg/util/syncutil/async_task_notifier_test.go new file mode 100644 index 0000000000000..b88ad0b81bd15 --- /dev/null +++ b/pkg/util/syncutil/async_task_notifier_test.go @@ -0,0 +1,57 @@ +package syncutil + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAsyncTaskNotifier(t *testing.T) { + n := NewAsyncTaskNotifier[error]() + assert.NotNil(t, n.Context()) + + select { + case <-n.FinishChan(): + t.Errorf("should not done") + return + case <-n.Context().Done(): + t.Error("should not cancel") + return + default: + } + + finishErr := errors.New("test") + + ch := make(chan struct{}) + go func() { + defer close(ch) + done := false + cancel := false + cancelCh := n.Context().Done() + doneCh := n.FinishChan() + for i := 0; ; i += 1 { + select { + case <-doneCh: + done = true + doneCh = nil + case <-cancelCh: + cancel = true + cancelCh = nil + n.Finish(finishErr) + } + if cancel && done { + return + } + if i == 0 { + assert.True(t, cancel && !done) + } else if i == 1 { + assert.True(t, cancel && done) + } + } + }() + n.Cancel() + n.BlockUntilFinish() + assert.ErrorIs(t, n.BlockAndGetResult(), finishErr) + <-ch +} diff --git a/pkg/util/typeutil/version.go b/pkg/util/typeutil/version.go new file mode 100644 index 0000000000000..31733fcbaae39 --- /dev/null +++ b/pkg/util/typeutil/version.go @@ -0,0 +1,56 @@ +package typeutil + +// Version is a interface for version comparison. +type Version interface { + // GT returns true if v > v2. + GT(Version) bool + + // EQ returns true if v == v2. + EQ(Version) bool +} + +// VersionInt64 is a int64 type version. +type VersionInt64 int64 + +func (v VersionInt64) GT(v2 Version) bool { + return v > mustCastVersionInt64(v2) +} + +func (v VersionInt64) EQ(v2 Version) bool { + return v == mustCastVersionInt64(v2) +} + +func mustCastVersionInt64(v2 Version) VersionInt64 { + if v2i, ok := v2.(VersionInt64); ok { + return v2i + } else if v2i, ok := v2.(*VersionInt64); ok { + return *v2i + } + panic("invalid version type") +} + +// VersionInt64Pair is a pair of int64 type version. +// It's easy to be used in multi node version comparison. +type VersionInt64Pair struct { + Global int64 + Local int64 +} + +func (v VersionInt64Pair) GT(v2 Version) bool { + vPair := mustCastVersionInt64Pair(v2) + return v.Global > vPair.Global || (v.Global == vPair.Global && v.Local > vPair.Local) +} + +func (v VersionInt64Pair) EQ(v2 Version) bool { + vPair := mustCastVersionInt64Pair(v2) + return v.Global == vPair.Global && v.Local == vPair.Local +} + +func mustCastVersionInt64Pair(v2 Version) VersionInt64Pair { + if v2i, ok := v2.(VersionInt64Pair); ok { + return v2i + } else if v2i, ok := v2.(*VersionInt64Pair); ok { + return *v2i + } + panic("invalid version type") +} diff --git a/pkg/util/typeutil/version_test.go b/pkg/util/typeutil/version_test.go new file mode 100644 index 0000000000000..594d5e4d7071d --- /dev/null +++ b/pkg/util/typeutil/version_test.go @@ -0,0 +1,29 @@ +package typeutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestVersion(t *testing.T) { + assert.True(t, VersionInt64(1).GT(VersionInt64(0))) + assert.True(t, VersionInt64(0).EQ(VersionInt64(0))) + v := VersionInt64(0) + assert.True(t, VersionInt64(1).GT(&v)) + assert.True(t, VersionInt64(0).EQ(&v)) + assert.Panics(t, func() { + VersionInt64(0).GT(VersionInt64Pair{Global: 1, Local: 1}) + }) + + assert.True(t, VersionInt64Pair{Global: 1, Local: 2}.GT(VersionInt64Pair{Global: 1, Local: 1})) + assert.True(t, VersionInt64Pair{Global: 2, Local: 0}.GT(VersionInt64Pair{Global: 1, Local: 1})) + assert.True(t, VersionInt64Pair{Global: 1, Local: 1}.EQ(VersionInt64Pair{Global: 1, Local: 1})) + v2 := VersionInt64Pair{Global: 1, Local: 1} + assert.True(t, VersionInt64Pair{Global: 1, Local: 2}.GT(&v2)) + assert.True(t, VersionInt64Pair{Global: 2, Local: 0}.GT(&v2)) + assert.True(t, VersionInt64Pair{Global: 1, Local: 1}.EQ(&v2)) + assert.Panics(t, func() { + VersionInt64Pair{Global: 1, Local: 2}.GT(VersionInt64(0)) + }) +}