From bd640754ac7130c1afef4f049b6720abe59f82c8 Mon Sep 17 00:00:00 2001 From: sre-ci-robot <56469371+sre-ci-robot@users.noreply.github.com> Date: Thu, 14 Dec 2023 17:44:39 +0800 Subject: [PATCH 01/14] enhance: [skip e2e] skip maximize build space plugin if it is self-hosted runner (#29214) Signed-off-by: Sammy Huang Co-authored-by: Sammy Huang --- .github/workflows/code-checker.yaml | 2 ++ .github/workflows/main.yaml | 1 + .github/workflows/publish-builder.yaml | 1 + .github/workflows/publish-gpu-builder.yaml | 1 + 4 files changed, 5 insertions(+) diff --git a/.github/workflows/code-checker.yaml b/.github/workflows/code-checker.yaml index 0b55d7533e781..78ef9833ed84a 100644 --- a/.github/workflows/code-checker.yaml +++ b/.github/workflows/code-checker.yaml @@ -41,6 +41,7 @@ jobs: steps: - name: Maximize build space uses: easimon/maximize-build-space@master + if: ${{ ! startsWith(runner.name, 'self') }} # skip this step if it is self-hosted runner with: root-reserve-mb: 20480 swap-size-mb: 1024 @@ -88,6 +89,7 @@ jobs: steps: - name: Maximize build space uses: easimon/maximize-build-space@master + if: ${{ ! startsWith(runner.name, 'self') }} # skip this step if it is self-hosted runner with: root-reserve-mb: 20480 swap-size-mb: 1024 diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 325fb64da249a..ce8f23d6ba94f 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -56,6 +56,7 @@ jobs: - name: Maximize build space uses: easimon/maximize-build-space@master + if: ${{ ! startsWith(runner.name, 'self') }} # skip this step if it is self-hosted runner with: root-reserve-mb: 20480 swap-size-mb: 1024 diff --git a/.github/workflows/publish-builder.yaml b/.github/workflows/publish-builder.yaml index 377efcdb23cd8..218d326d78c64 100644 --- a/.github/workflows/publish-builder.yaml +++ b/.github/workflows/publish-builder.yaml @@ -36,6 +36,7 @@ jobs: steps: - name: Maximize build space uses: easimon/maximize-build-space@master + if: ${{ ! startsWith(runner.name, 'self') }} # skip this step if it is self-hosted runner with: root-reserve-mb: 20480 # overprovision-lvm: 'true' diff --git a/.github/workflows/publish-gpu-builder.yaml b/.github/workflows/publish-gpu-builder.yaml index 6e0ca62fad3c3..48bc27d6beb97 100644 --- a/.github/workflows/publish-gpu-builder.yaml +++ b/.github/workflows/publish-gpu-builder.yaml @@ -36,6 +36,7 @@ jobs: steps: - name: Maximize build space uses: easimon/maximize-build-space@master + if: ${{ ! 
startsWith(runner.name, 'self') }} # skip this step if it is self-hosted runner with: root-reserve-mb: 20480 # overprovision-lvm: 'true' From b8674811cf7c57b8a73bb2121c256a8873ce5d49 Mon Sep 17 00:00:00 2001 From: yah01 Date: Thu, 14 Dec 2023 18:22:39 +0800 Subject: [PATCH 02/14] fix: data race in ProxyClientManager (#29206) this PR changed the ProxyClientManager to thread-safe fix #29205 Signed-off-by: yah01 --- internal/rootcoord/mock_test.go | 8 +- internal/rootcoord/proxy_client_manager.go | 106 +++++++----------- .../rootcoord/proxy_client_manager_test.go | 57 ++++------ internal/rootcoord/quota_center_test.go | 5 +- internal/rootcoord/root_coord.go | 10 +- 5 files changed, 76 insertions(+), 110 deletions(-) diff --git a/internal/rootcoord/mock_test.go b/internal/rootcoord/mock_test.go index 679189ff1d37d..6e06bf2ec5e0f 100644 --- a/internal/rootcoord/mock_test.go +++ b/internal/rootcoord/mock_test.go @@ -386,7 +386,7 @@ func newTestCore(opts ...Opt) *Core { func withValidProxyManager() Opt { return func(c *Core) { c.proxyClientManager = &proxyClientManager{ - proxyClient: make(map[UniqueID]types.ProxyClient), + proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient](), } p := newMockProxy() p.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { @@ -398,14 +398,14 @@ func withValidProxyManager() Opt { Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, }, nil } - c.proxyClientManager.proxyClient[TestProxyID] = p + c.proxyClientManager.proxyClient.Insert(TestProxyID, p) } } func withInvalidProxyManager() Opt { return func(c *Core) { c.proxyClientManager = &proxyClientManager{ - proxyClient: make(map[UniqueID]types.ProxyClient), + proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient](), } p := newMockProxy() p.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { @@ -417,7 +417,7 @@ func withInvalidProxyManager() Opt { Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, }, nil } - c.proxyClientManager.proxyClient[TestProxyID] = p + c.proxyClientManager.proxyClient.Insert(TestProxyID, p) } } diff --git a/internal/rootcoord/proxy_client_manager.go b/internal/rootcoord/proxy_client_manager.go index f200199c9d12e..0141ef4a55923 100644 --- a/internal/rootcoord/proxy_client_manager.go +++ b/internal/rootcoord/proxy_client_manager.go @@ -35,6 +35,7 @@ import ( "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type proxyCreator func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) @@ -49,8 +50,7 @@ func DefaultProxyCreator(ctx context.Context, addr string, nodeID int64) (types. 
type proxyClientManager struct { creator proxyCreator - lock sync.RWMutex - proxyClient map[int64]types.ProxyClient + proxyClient *typeutil.ConcurrentMap[int64, types.ProxyClient] helper proxyClientManagerHelper } @@ -65,7 +65,7 @@ var defaultClientManagerHelper = proxyClientManagerHelper{ func newProxyClientManager(creator proxyCreator) *proxyClientManager { return &proxyClientManager{ creator: creator, - proxyClient: make(map[int64]types.ProxyClient), + proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient](), helper: defaultClientManagerHelper, } } @@ -76,16 +76,12 @@ func (p *proxyClientManager) AddProxyClients(sessions []*sessionutil.Session) { } } -func (p *proxyClientManager) GetProxyClients() map[int64]types.ProxyClient { - p.lock.RLock() - defer p.lock.RUnlock() +func (p *proxyClientManager) GetProxyClients() *typeutil.ConcurrentMap[int64, types.ProxyClient] { return p.proxyClient } func (p *proxyClientManager) AddProxyClient(session *sessionutil.Session) { - p.lock.RLock() - _, ok := p.proxyClient[session.ServerID] - p.lock.RUnlock() + _, ok := p.proxyClient.Get(session.ServerID) if ok { return } @@ -96,15 +92,12 @@ func (p *proxyClientManager) AddProxyClient(session *sessionutil.Session) { // GetProxyCount returns number of proxy clients. func (p *proxyClientManager) GetProxyCount() int { - p.lock.Lock() - defer p.lock.Unlock() - - return len(p.proxyClient) + return p.proxyClient.Len() } // mutex.Lock is required before calling this method. func (p *proxyClientManager) updateProxyNumMetric() { - metrics.RootCoordProxyCounter.WithLabelValues().Set(float64(len(p.proxyClient))) + metrics.RootCoordProxyCounter.WithLabelValues().Set(float64(p.proxyClient.Len())) } func (p *proxyClientManager) connect(session *sessionutil.Session) { @@ -114,51 +107,40 @@ func (p *proxyClientManager) connect(session *sessionutil.Session) { return } - p.lock.Lock() - defer p.lock.Unlock() - - _, ok := p.proxyClient[session.ServerID] + _, ok := p.proxyClient.GetOrInsert(session.GetServerID(), pc) if ok { pc.Close() return } - p.proxyClient[session.ServerID] = pc log.Info("succeed to create proxy client", zap.String("address", session.Address), zap.Int64("serverID", session.ServerID)) p.helper.afterConnect() } func (p *proxyClientManager) DelProxyClient(s *sessionutil.Session) { - p.lock.Lock() - defer p.lock.Unlock() - - cli, ok := p.proxyClient[s.ServerID] + cli, ok := p.proxyClient.GetAndRemove(s.GetServerID()) if ok { cli.Close() } - delete(p.proxyClient, s.ServerID) p.updateProxyNumMetric() log.Info("remove proxy client", zap.String("proxy address", s.Address), zap.Int64("proxy id", s.ServerID)) } func (p *proxyClientManager) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest, opts ...expireCacheOpt) error { - p.lock.Lock() - defer p.lock.Unlock() - c := defaultExpireCacheConfig() for _, opt := range opts { opt(&c) } c.apply(request) - if len(p.proxyClient) == 0 { + if p.proxyClient.Len() == 0 { log.Warn("proxy client is empty, InvalidateCollectionMetaCache will not send to any client") return nil } group := &errgroup.Group{} - for k, v := range p.proxyClient { - k, v := k, v + p.proxyClient.Range(func(key int64, value types.ProxyClient) bool { + k, v := key, value group.Go(func() error { sta, err := v.InvalidateCollectionMetaCache(ctx, request) if err != nil { @@ -173,23 +155,21 @@ func (p *proxyClientManager) InvalidateCollectionMetaCache(ctx context.Context, } return nil }) - } + return true + }) return group.Wait() } // 
InvalidateCredentialCache TODO: too many codes similar to InvalidateCollectionMetaCache. func (p *proxyClientManager) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) error { - p.lock.Lock() - defer p.lock.Unlock() - - if len(p.proxyClient) == 0 { + if p.proxyClient.Len() == 0 { log.Warn("proxy client is empty, InvalidateCredentialCache will not send to any client") return nil } group := &errgroup.Group{} - for k, v := range p.proxyClient { - k, v := k, v + p.proxyClient.Range(func(key int64, value types.ProxyClient) bool { + k, v := key, value group.Go(func() error { sta, err := v.InvalidateCredentialCache(ctx, request) if err != nil { @@ -200,23 +180,22 @@ func (p *proxyClientManager) InvalidateCredentialCache(ctx context.Context, requ } return nil }) - } + return true + }) + return group.Wait() } // UpdateCredentialCache TODO: too many codes similar to InvalidateCollectionMetaCache. func (p *proxyClientManager) UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) error { - p.lock.Lock() - defer p.lock.Unlock() - - if len(p.proxyClient) == 0 { + if p.proxyClient.Len() == 0 { log.Warn("proxy client is empty, UpdateCredentialCache will not send to any client") return nil } group := &errgroup.Group{} - for k, v := range p.proxyClient { - k, v := k, v + p.proxyClient.Range(func(key int64, value types.ProxyClient) bool { + k, v := key, value group.Go(func() error { sta, err := v.UpdateCredentialCache(ctx, request) if err != nil { @@ -227,23 +206,21 @@ func (p *proxyClientManager) UpdateCredentialCache(ctx context.Context, request } return nil }) - } + return true + }) return group.Wait() } // RefreshPolicyInfoCache TODO: too many codes similar to InvalidateCollectionMetaCache. func (p *proxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error { - p.lock.Lock() - defer p.lock.Unlock() - - if len(p.proxyClient) == 0 { + if p.proxyClient.Len() == 0 { log.Warn("proxy client is empty, RefreshPrivilegeInfoCache will not send to any client") return nil } group := &errgroup.Group{} - for k, v := range p.proxyClient { - k, v := k, v + p.proxyClient.Range(func(key int64, value types.ProxyClient) bool { + k, v := key, value group.Go(func() error { status, err := v.RefreshPolicyInfoCache(ctx, req) if err != nil { @@ -254,16 +231,14 @@ func (p *proxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *pr } return nil }) - } + return true + }) return group.Wait() } // GetProxyMetrics sends requests to proxies to get metrics. 
func (p *proxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.GetMetricsResponse, error) { - p.lock.Lock() - defer p.lock.Unlock() - - if len(p.proxyClient) == 0 { + if p.proxyClient.Len() == 0 { log.Warn("proxy client is empty, GetMetrics will not send to any client") return nil, nil } @@ -276,8 +251,8 @@ func (p *proxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.G group := &errgroup.Group{} var metricRspsMu sync.Mutex metricRsps := make([]*milvuspb.GetMetricsResponse, 0) - for k, v := range p.proxyClient { - k, v := k, v + p.proxyClient.Range(func(key int64, value types.ProxyClient) bool { + k, v := key, value group.Go(func() error { rsp, err := v.GetProxyMetrics(ctx, req) if err != nil { @@ -291,7 +266,8 @@ func (p *proxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.G metricRspsMu.Unlock() return nil }) - } + return true + }) err = group.Wait() if err != nil { return nil, err @@ -301,17 +277,14 @@ func (p *proxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.G // SetRates notifies Proxy to limit rates of requests. func (p *proxyClientManager) SetRates(ctx context.Context, request *proxypb.SetRatesRequest) error { - p.lock.Lock() - defer p.lock.Unlock() - - if len(p.proxyClient) == 0 { + if p.proxyClient.Len() == 0 { log.Warn("proxy client is empty, SetRates will not send to any client") return nil } group := &errgroup.Group{} - for k, v := range p.proxyClient { - k, v := k, v + p.proxyClient.Range(func(key int64, value types.ProxyClient) bool { + k, v := key, value group.Go(func() error { sta, err := v.SetRates(ctx, request) if err != nil { @@ -322,6 +295,7 @@ func (p *proxyClientManager) SetRates(ctx context.Context, request *proxypb.SetR } return nil }) - } + return true + }) return group.Wait() } diff --git a/internal/rootcoord/proxy_client_manager_test.go b/internal/rootcoord/proxy_client_manager_test.go index 3edd271562bde..dc3a6dbe17f76 100644 --- a/internal/rootcoord/proxy_client_manager_test.go +++ b/internal/rootcoord/proxy_client_manager_test.go @@ -32,6 +32,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type proxyMock struct { @@ -164,7 +165,7 @@ func TestProxyClientManager_AddProxyClient(t *testing.T) { func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) { t.Run("empty proxy list", func(t *testing.T) { ctx := context.Background() - pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{}} + pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) assert.NoError(t, err) }) @@ -175,9 +176,8 @@ func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) { p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { return merr.Success(), errors.New("error mock InvalidateCollectionMetaCache") } - pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ - TestProxyID: p1, - }} + pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} + pcm.proxyClient.Insert(TestProxyID, p1) err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) assert.Error(t, err) }) @@ -189,9 +189,8 @@ func 
TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) { p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { return merr.Status(mockErr), nil } - pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ - TestProxyID: p1, - }} + pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} + pcm.proxyClient.Insert(TestProxyID, p1) err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) assert.Error(t, err) }) @@ -202,9 +201,8 @@ func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) { p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { return nil, merr.ErrNodeNotFound } - pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ - TestProxyID: p1, - }} + pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} + pcm.proxyClient.Insert(TestProxyID, p1) err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) assert.NoError(t, err) @@ -216,9 +214,8 @@ func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) { p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { return merr.Success(), nil } - pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ - TestProxyID: p1, - }} + pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} + pcm.proxyClient.Insert(TestProxyID, p1) err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) assert.NoError(t, err) }) @@ -227,7 +224,7 @@ func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) { func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) { t.Run("empty proxy list", func(t *testing.T) { ctx := context.Background() - pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{}} + pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) assert.NoError(t, err) }) @@ -238,9 +235,8 @@ func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) { p1.InvalidateCredentialCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { return merr.Success(), errors.New("error mock InvalidateCredentialCache") } - pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ - TestProxyID: p1, - }} + pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} + pcm.proxyClient.Insert(TestProxyID, p1) err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) assert.Error(t, err) }) @@ -252,9 +248,8 @@ func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) { p1.InvalidateCredentialCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { return merr.Status(mockErr), nil } - pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ - TestProxyID: p1, - }} + pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} + pcm.proxyClient.Insert(TestProxyID, p1) err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) 
assert.Error(t, err) }) @@ -265,9 +260,8 @@ func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) { p1.InvalidateCredentialCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { return merr.Success(), nil } - pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ - TestProxyID: p1, - }} + pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} + pcm.proxyClient.Insert(TestProxyID, p1) err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) assert.NoError(t, err) }) @@ -276,7 +270,7 @@ func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) { func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) { t.Run("empty proxy list", func(t *testing.T) { ctx := context.Background() - pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{}} + pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) assert.NoError(t, err) }) @@ -287,9 +281,8 @@ func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) { p1.RefreshPolicyInfoCacheFunc = func(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { return merr.Success(), errors.New("error mock RefreshPolicyInfoCache") } - pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ - TestProxyID: p1, - }} + pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} + pcm.proxyClient.Insert(TestProxyID, p1) err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) assert.Error(t, err) }) @@ -301,9 +294,8 @@ func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) { p1.RefreshPolicyInfoCacheFunc = func(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { return merr.Status(mockErr), nil } - pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ - TestProxyID: p1, - }} + pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} + pcm.proxyClient.Insert(TestProxyID, p1) err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) assert.Error(t, err) }) @@ -314,9 +306,8 @@ func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) { p1.RefreshPolicyInfoCacheFunc = func(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { return merr.Success(), nil } - pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ - TestProxyID: p1, - }} + pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} + pcm.proxyClient.Insert(TestProxyID, p1) err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) assert.NoError(t, err) }) diff --git a/internal/rootcoord/quota_center_test.go b/internal/rootcoord/quota_center_test.go index e14b6fb3e6b2c..670ec78212827 100644 --- a/internal/rootcoord/quota_center_test.go +++ b/internal/rootcoord/quota_center_test.go @@ -533,9 +533,8 @@ func TestQuotaCenter(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) p1 := mocks.NewMockProxyClient(t) p1.EXPECT().SetRates(mock.Anything, mock.Anything).Return(nil, nil) - pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ - TestProxyID: p1, - }} + pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, 
types.ProxyClient]()} + pcm.proxyClient.Insert(TestProxyID, p1) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index f093d0f1f9613..d39c78ebf6f8a 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -2782,9 +2782,10 @@ func (c *Core) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest) group, ctx := errgroup.WithContext(ctx) errReasons := make([]string, 0, c.proxyClientManager.GetProxyCount()) - for nodeID, proxyClient := range c.proxyClientManager.GetProxyClients() { - nodeID := nodeID - proxyClient := proxyClient + proxyClients := c.proxyClientManager.GetProxyClients() + proxyClients.Range(func(key int64, value types.ProxyClient) bool { + nodeID := key + proxyClient := value group.Go(func() error { sta, err := proxyClient.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) if err != nil { @@ -2799,7 +2800,8 @@ func (c *Core) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest) } return nil }) - } + return true + }) err := group.Wait() if err != nil || len(errReasons) != 0 { From 13beb5ccc0382356688780285e7c2e491a8a4f38 Mon Sep 17 00:00:00 2001 From: yah01 Date: Thu, 14 Dec 2023 18:28:38 +0800 Subject: [PATCH 03/14] fix: load gets stuck probably (#29191) we found the load got stuck probably, and reviewed the logs. the target observer seems not working, the reason is the taskDispatcher removes the task in a goroutine, and modifies the task status after committing the task into the goroutine pool, but this may happen after the task removed, which leads to the task will never be removed related #29086 Signed-off-by: yah01 --- internal/querycoordv2/observers/task_dispatcher.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/querycoordv2/observers/task_dispatcher.go b/internal/querycoordv2/observers/task_dispatcher.go index 720d415422003..29ede76b1bb28 100644 --- a/internal/querycoordv2/observers/task_dispatcher.go +++ b/internal/querycoordv2/observers/task_dispatcher.go @@ -96,12 +96,12 @@ func (d *taskDispatcher[K]) schedule(ctx context.Context) { case <-d.notifyCh: d.tasks.Range(func(k K, submitted bool) bool { if !submitted { + d.tasks.Insert(k, true) d.pool.Submit(func() (any, error) { d.taskRunner(ctx, k) d.tasks.Remove(k) return struct{}{}, nil }) - d.tasks.Insert(k, true) } return true }) From cc727ace61b5d115377cb3f2474bf343bb952cd6 Mon Sep 17 00:00:00 2001 From: XuanYang-cn Date: Thu, 14 Dec 2023 18:46:40 +0800 Subject: [PATCH 04/14] fix: Set compacted segments' level to level one (#29190) Signed-off-by: yangxuan --- internal/datacoord/meta.go | 2 ++ internal/datacoord/meta_test.go | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/internal/datacoord/meta.go b/internal/datacoord/meta.go index 4b95011272bc5..29070d7388a1c 100644 --- a/internal/datacoord/meta.go +++ b/internal/datacoord/meta.go @@ -1041,6 +1041,7 @@ func (m *meta) PrepareCompleteCompactionMutation(plan *datapb.CompactionPlan, CreatedByCompaction: true, CompactionFrom: compactionFrom, LastExpireTime: plan.GetStartTime(), + Level: datapb.SegmentLevel_L1, } segment := NewSegmentInfo(segmentInfo) metricMutation.addNewSeg(segment.GetState(), segment.GetLevel(), segment.GetNumOfRows()) @@ -1048,6 +1049,7 @@ func (m 
*meta) PrepareCompleteCompactionMutation(plan *datapb.CompactionPlan, zap.Int64("collectionID", segment.GetCollectionID()), zap.Int64("partitionID", segment.GetPartitionID()), zap.Int64("new segment ID", segment.GetID()), + zap.String("new segment level", segment.GetLevel().String()), zap.Int64("new segment num of rows", segment.GetNumOfRows()), zap.Any("compacted from", segment.GetCompactionFrom())) diff --git a/internal/datacoord/meta_test.go b/internal/datacoord/meta_test.go index da385608dae82..7d8ab5bb45021 100644 --- a/internal/datacoord/meta_test.go +++ b/internal/datacoord/meta_test.go @@ -730,7 +730,8 @@ func TestMeta_PrepareCompleteCompactionMutation(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, afterCompact) assert.NotNil(t, newSegment) - assert.Equal(t, 3, len(metricMutation.stateChange[datapb.SegmentLevel_Legacy.String()])) + assert.Equal(t, 2, len(metricMutation.stateChange[datapb.SegmentLevel_Legacy.String()])) + assert.Equal(t, 1, len(metricMutation.stateChange[datapb.SegmentLevel_L1.String()])) assert.Equal(t, int64(0), metricMutation.rowCountChange) assert.Equal(t, int64(2), metricMutation.rowCountAccChange) From 8a63e53421436fdcb5e05ee71bb40091205bdbf9 Mon Sep 17 00:00:00 2001 From: congqixia Date: Thu, 14 Dec 2023 19:26:39 +0800 Subject: [PATCH 05/14] enhance: Add http method to control datacoord garbage collection (#29052) See also #29051 --------- Signed-off-by: Congqi Xia --- internal/datacoord/garbage_collector.go | 74 +++++++- internal/datacoord/garbage_collector_test.go | 132 ++++++++++++++ internal/datacoord/server.go | 38 ++++ internal/datacoord/services.go | 43 +++++ internal/datacoord/services_test.go | 92 ++++++++++ .../distributed/datacoord/client/client.go | 6 + .../datacoord/client/client_test.go | 37 ++++ internal/distributed/datacoord/service.go | 4 + .../distributed/datacoord/service_test.go | 8 + internal/http/server.go | 8 + internal/mocks/mock_datacoord.go | 55 ++++++ internal/mocks/mock_datacoord_client.go | 70 ++++++++ internal/proto/data_coord.proto | 14 ++ internal/proxy/management.go | 93 ++++++++++ internal/proxy/management_test.go | 163 ++++++++++++++++++ 15 files changed, 833 insertions(+), 4 deletions(-) create mode 100644 internal/proxy/management.go create mode 100644 internal/proxy/management_test.go diff --git a/internal/datacoord/garbage_collector.go b/internal/datacoord/garbage_collector.go index cdec36bda3dff..ccf3e96d6e237 100644 --- a/internal/datacoord/garbage_collector.go +++ b/internal/datacoord/garbage_collector.go @@ -27,6 +27,7 @@ import ( "github.com/minio/minio-go/v7" "github.com/samber/lo" + "go.uber.org/atomic" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -35,6 +36,7 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -56,10 +58,17 @@ type garbageCollector struct { meta *meta handler Handler - startOnce sync.Once - stopOnce sync.Once - wg sync.WaitGroup - closeCh chan struct{} + startOnce sync.Once + stopOnce sync.Once + wg sync.WaitGroup + closeCh chan struct{} + cmdCh chan gcCmd + pauseUntil atomic.Time +} +type gcCmd struct { + cmdType datapb.GcCommand + duration time.Duration + done chan struct{} } // newGarbageCollector create garbage collector with meta and option @@ -71,6 +80,7 @@ func 
newGarbageCollector(meta *meta, handler Handler, opt GcOption) *garbageColl handler: handler, option: opt, closeCh: make(chan struct{}), + cmdCh: make(chan gcCmd), } } @@ -88,6 +98,43 @@ func (gc *garbageCollector) start() { } } +func (gc *garbageCollector) Pause(ctx context.Context, pauseDuration time.Duration) error { + if !gc.option.enabled { + log.Info("garbage collection not enabled") + return nil + } + done := make(chan struct{}) + select { + case gc.cmdCh <- gcCmd{ + cmdType: datapb.GcCommand_Pause, + duration: pauseDuration, + done: done, + }: + <-done + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (gc *garbageCollector) Resume(ctx context.Context) error { + if !gc.option.enabled { + log.Warn("garbage collection not enabled, cannot resume") + return merr.WrapErrServiceUnavailable("garbage collection not enabled") + } + done := make(chan struct{}) + select { + case gc.cmdCh <- gcCmd{ + cmdType: datapb.GcCommand_Resume, + done: done, + }: + <-done + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // work contains actual looping check logic func (gc *garbageCollector) work() { defer gc.wg.Done() @@ -96,11 +143,30 @@ func (gc *garbageCollector) work() { for { select { case <-ticker.C: + if time.Now().Before(gc.pauseUntil.Load()) { + log.Info("garbage collector paused", zap.Time("until", gc.pauseUntil.Load())) + continue + } gc.clearEtcd() gc.recycleUnusedIndexes() gc.recycleUnusedSegIndexes() gc.scan() gc.recycleUnusedIndexFiles() + case cmd := <-gc.cmdCh: + switch cmd.cmdType { + case datapb.GcCommand_Pause: + pauseUntil := time.Now().Add(cmd.duration) + if pauseUntil.After(gc.pauseUntil.Load()) { + log.Info("garbage collection paused", zap.Duration("duration", cmd.duration), zap.Time("pauseUntil", pauseUntil)) + gc.pauseUntil.Store(pauseUntil) + } else { + log.Info("new pause until before current value", zap.Duration("duration", cmd.duration), zap.Time("pauseUntil", pauseUntil), zap.Time("oldPauseUntil", gc.pauseUntil.Load())) + } + case datapb.GcCommand_Resume: + // reset to zero value + gc.pauseUntil.Store(time.Time{}) + } + close(cmd.done) case <-gc.closeCh: log.Warn("garbage collector quit") return diff --git a/internal/datacoord/garbage_collector_test.go b/internal/datacoord/garbage_collector_test.go index 1999b9dd5086a..5442ea55383a8 100644 --- a/internal/datacoord/garbage_collector_test.go +++ b/internal/datacoord/garbage_collector_test.go @@ -32,6 +32,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" @@ -1134,3 +1135,134 @@ func TestGarbageCollector_clearETCD(t *testing.T) { segB = gc.meta.GetSegment(segID + 1) assert.Nil(t, segB) } + +type GarbageCollectorSuite struct { + suite.Suite + + bucketName string + rootPath string + + cli *storage.MinioChunkManager + inserts []string + stats []string + delta []string + others []string + + meta *meta +} + +func (s *GarbageCollectorSuite) SetupTest() { + s.bucketName = `datacoord-ut` + strings.ToLower(funcutil.RandomString(8)) + s.rootPath = `gc` + funcutil.RandomString(8) + + var err error + s.cli, s.inserts, s.stats, s.delta, s.others, err = initUtOSSEnv(s.bucketName, s.rootPath, 4) + s.Require().NoError(err) + + s.meta, err = newMemoryMeta() + s.Require().NoError(err) +} + +func (s *GarbageCollectorSuite) TearDownTest() { + cleanupOSS(s.cli.Client, s.bucketName, s.rootPath) +} + +func 
(s *GarbageCollectorSuite) TestPauseResume() { + s.Run("not_enabled", func() { + gc := newGarbageCollector(s.meta, newMockHandler(), GcOption{ + cli: s.cli, + enabled: false, + checkInterval: time.Millisecond * 10, + missingTolerance: time.Hour * 24, + dropTolerance: time.Hour * 24, + }) + + gc.start() + defer gc.close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err := gc.Pause(ctx, time.Second) + s.NoError(err) + + err = gc.Resume(ctx) + s.Error(err) + }) + + s.Run("pause_then_resume", func() { + gc := newGarbageCollector(s.meta, newMockHandler(), GcOption{ + cli: s.cli, + enabled: true, + checkInterval: time.Millisecond * 10, + missingTolerance: time.Hour * 24, + dropTolerance: time.Hour * 24, + }) + + gc.start() + defer gc.close() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err := gc.Pause(ctx, time.Minute) + s.NoError(err) + + s.NotZero(gc.pauseUntil.Load()) + + err = gc.Resume(ctx) + s.NoError(err) + + s.Zero(gc.pauseUntil.Load()) + }) + + s.Run("pause_before_until", func() { + gc := newGarbageCollector(s.meta, newMockHandler(), GcOption{ + cli: s.cli, + enabled: true, + checkInterval: time.Millisecond * 10, + missingTolerance: time.Hour * 24, + dropTolerance: time.Hour * 24, + }) + + gc.start() + defer gc.close() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err := gc.Pause(ctx, time.Minute) + s.NoError(err) + + until := gc.pauseUntil.Load() + s.NotZero(until) + + err = gc.Pause(ctx, time.Second) + s.NoError(err) + + second := gc.pauseUntil.Load() + + s.Equal(until, second) + }) + + s.Run("pause_resume_timeout", func() { + gc := newGarbageCollector(s.meta, newMockHandler(), GcOption{ + cli: s.cli, + enabled: true, + checkInterval: time.Millisecond * 10, + missingTolerance: time.Hour * 24, + dropTolerance: time.Hour * 24, + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + err := gc.Pause(ctx, time.Minute) + s.Error(err) + + s.Zero(gc.pauseUntil.Load()) + + err = gc.Resume(ctx) + s.Error(err) + + s.Zero(gc.pauseUntil.Load()) + }) +} + +func TestGarbageCollector(t *testing.T) { + suite.Run(t, new(GarbageCollectorSuite)) +} diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index 9d8eed2c861cc..f39b62dac8db1 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -405,6 +405,44 @@ func (s *Server) startDataCoord() { s.compactionViewManager.Start() } s.startServerLoop() + + // http.Register(&http.Handler{ + // Path: "/datacoord/garbage_collection/pause", + // HandlerFunc: func(w http.ResponseWriter, req *http.Request) { + // pauseSeconds := req.URL.Query().Get("pause_seconds") + // seconds, err := strconv.ParseInt(pauseSeconds, 10, 64) + // if err != nil { + // w.WriteHeader(400) + // w.Write([]byte(fmt.Sprintf(`{"msg": "invalid pause seconds(%v)"}`, pauseSeconds))) + // return + // } + + // err = s.garbageCollector.Pause(req.Context(), time.Duration(seconds)*time.Second) + // if err != nil { + // w.WriteHeader(500) + // w.Write([]byte(fmt.Sprintf(`{"msg": "failed to pause garbage collection, %s"}`, err.Error()))) + // return + // } + // w.WriteHeader(200) + // w.Write([]byte(`{"msg": "OK"}`)) + // return + // }, + // }) + // http.Register(&http.Handler{ + // Path: "/datacoord/garbage_collection/resume", + // HandlerFunc: func(w http.ResponseWriter, req *http.Request) { + // err := s.garbageCollector.Resume(req.Context()) + // if err != nil { + // w.WriteHeader(500) + // 
w.Write([]byte(fmt.Sprintf(`{"msg": "failed to pause garbage collection, %s"}`, err.Error()))) + // return + // } + // w.WriteHeader(200) + // w.Write([]byte(`{"msg": "OK"}`)) + // return + // }, + // }) + s.afterStart() s.stateCode.Store(commonpb.StateCode_Healthy) sessionutil.SaveServerInfo(typeutil.DataCoordRole, s.session.GetServerID()) diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 386235e4cbcd8..3ba22e7c6de84 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -21,6 +21,7 @@ import ( "fmt" "math/rand" "strconv" + "time" "github.com/cockroachdb/errors" "github.com/samber/lo" @@ -1637,3 +1638,45 @@ func (s *Server) GcConfirm(ctx context.Context, request *datapb.GcConfirmRequest resp.GcFinished = s.meta.GcConfirm(ctx, request.GetCollectionId(), request.GetPartitionId()) return resp, nil } + +func (s *Server) GcControl(ctx context.Context, request *datapb.GcControlRequest) (*commonpb.Status, error) { + status := &commonpb.Status{} + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return merr.Status(err), nil + } + + switch request.GetCommand() { + case datapb.GcCommand_Pause: + kv := lo.FindOrElse(request.GetParams(), nil, func(kv *commonpb.KeyValuePair) bool { + return kv.GetKey() == "duration" + }) + if kv == nil { + status.ErrorCode = commonpb.ErrorCode_UnexpectedError + status.Reason = "pause duration param not found" + return status, nil + } + pauseSeconds, err := strconv.ParseInt(kv.GetValue(), 10, 64) + if err != nil { + status.ErrorCode = commonpb.ErrorCode_UnexpectedError + status.Reason = fmt.Sprintf("pause duration not valid, %s", err.Error()) + return status, nil + } + if err := s.garbageCollector.Pause(ctx, time.Duration(pauseSeconds)*time.Second); err != nil { + status.ErrorCode = commonpb.ErrorCode_UnexpectedError + status.Reason = fmt.Sprintf("failed to pause gc, %s", err.Error()) + return status, nil + } + case datapb.GcCommand_Resume: + if err := s.garbageCollector.Resume(ctx); err != nil { + status.ErrorCode = commonpb.ErrorCode_UnexpectedError + status.Reason = fmt.Sprintf("failed to pause gc, %s", err.Error()) + return status, nil + } + default: + status.ErrorCode = commonpb.ErrorCode_UnexpectedError + status.Reason = fmt.Sprintf("unknown gc command: %d", request.GetCommand()) + return status, nil + } + + return status, nil +} diff --git a/internal/datacoord/services_test.go b/internal/datacoord/services_test.go index 592a7326c0f6e..119c8c8f605b1 100644 --- a/internal/datacoord/services_test.go +++ b/internal/datacoord/services_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" @@ -605,3 +606,94 @@ func TestGetRecoveryInfoV2(t *testing.T) { assert.ErrorIs(t, err, merr.ErrServiceNotReady) }) } + +type GcControlServiceSuite struct { + suite.Suite + + server *Server +} + +func (s *GcControlServiceSuite) SetupTest() { + s.server = newTestServer(s.T(), nil) +} + +func (s *GcControlServiceSuite) TearDownTest() { + if s.server != nil { + closeTestServer(s.T(), s.server) + } +} + +func (s *GcControlServiceSuite) TestClosedServer() { + closeTestServer(s.T(), s.server) + resp, err := s.server.GcControl(context.TODO(), &datapb.GcControlRequest{}) + s.NoError(err) + s.False(merr.Ok(resp)) + s.server = nil +} + +func (s *GcControlServiceSuite) TestUnknownCmd() { + 
resp, err := s.server.GcControl(context.TODO(), &datapb.GcControlRequest{ + Command: 0, + }) + s.NoError(err) + s.False(merr.Ok(resp)) +} + +func (s *GcControlServiceSuite) TestPause() { + resp, err := s.server.GcControl(context.TODO(), &datapb.GcControlRequest{ + Command: datapb.GcCommand_Pause, + }) + s.Nil(err) + s.False(merr.Ok(resp)) + + resp, err = s.server.GcControl(context.TODO(), &datapb.GcControlRequest{ + Command: datapb.GcCommand_Pause, + Params: []*commonpb.KeyValuePair{ + {Key: "duration", Value: "not_int"}, + }, + }) + s.Nil(err) + s.False(merr.Ok(resp)) + + resp, err = s.server.GcControl(context.TODO(), &datapb.GcControlRequest{ + Command: datapb.GcCommand_Pause, + Params: []*commonpb.KeyValuePair{ + {Key: "duration", Value: "60"}, + }, + }) + s.Nil(err) + s.True(merr.Ok(resp)) +} + +func (s *GcControlServiceSuite) TestResume() { + resp, err := s.server.GcControl(context.TODO(), &datapb.GcControlRequest{ + Command: datapb.GcCommand_Resume, + }) + s.Nil(err) + s.True(merr.Ok(resp)) +} + +func (s *GcControlServiceSuite) TestTimeoutCtx() { + s.server.garbageCollector.close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + resp, err := s.server.GcControl(ctx, &datapb.GcControlRequest{ + Command: datapb.GcCommand_Resume, + }) + s.Nil(err) + s.False(merr.Ok(resp)) + + resp, err = s.server.GcControl(ctx, &datapb.GcControlRequest{ + Command: datapb.GcCommand_Pause, + Params: []*commonpb.KeyValuePair{ + {Key: "duration", Value: "60"}, + }, + }) + s.Nil(err) + s.False(merr.Ok(resp)) +} + +func TestGcControlService(t *testing.T) { + suite.Run(t, new(GcControlServiceSuite)) +} diff --git a/internal/distributed/datacoord/client/client.go b/internal/distributed/datacoord/client/client.go index 111868369a5b3..c52fd3b8232a9 100644 --- a/internal/distributed/datacoord/client/client.go +++ b/internal/distributed/datacoord/client/client.go @@ -618,3 +618,9 @@ func (c *Client) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDat return client.ReportDataNodeTtMsgs(ctx, req) }) } + +func (c *Client) GcControl(ctx context.Context, req *datapb.GcControlRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*commonpb.Status, error) { + return client.GcControl(ctx, req) + }) +} diff --git a/internal/distributed/datacoord/client/client_test.go b/internal/distributed/datacoord/client/client_test.go index f051757e0b867..c4f13539a29ec 100644 --- a/internal/distributed/datacoord/client/client_test.go +++ b/internal/distributed/datacoord/client/client_test.go @@ -1823,3 +1823,40 @@ func Test_ReportDataNodeTtMsgs(t *testing.T) { _, err = client.ReportDataNodeTtMsgs(ctx, &datapb.ReportDataNodeTtMsgsRequest{}) assert.ErrorIs(t, err, context.DeadlineExceeded) } + +func Test_GcControl(t *testing.T) { + paramtable.Init() + + ctx := context.Background() + client, err := NewClient(ctx) + assert.NoError(t, err) + assert.NotNil(t, client) + defer client.Close() + + mockProxy := mocks.NewMockDataCoordClient(t) + mockGrpcClient := mocks.NewMockGrpcClient[datapb.DataCoordClient](t) + mockGrpcClient.EXPECT().Close().Return(nil) + mockGrpcClient.EXPECT().ReCall(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, f func(datapb.DataCoordClient) (interface{}, error)) (interface{}, error) { + return f(mockProxy) + }) + client.grpcClient = mockGrpcClient + + // test success + mockProxy.EXPECT().GcControl(mock.Anything, mock.Anything).Return(merr.Success(), nil) + _, err = client.GcControl(ctx, 
&datapb.GcControlRequest{}) + assert.Nil(t, err) + + // test return error code + mockProxy.ExpectedCalls = nil + mockProxy.EXPECT().GcControl(mock.Anything, mock.Anything).Return(merr.Status(err), nil) + + _, err = client.GcControl(ctx, &datapb.GcControlRequest{}) + assert.Nil(t, err) + + // test ctx done + ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + defer cancel() + time.Sleep(20 * time.Millisecond) + _, err = client.GcControl(ctx, &datapb.GcControlRequest{}) + assert.ErrorIs(t, err, context.DeadlineExceeded) +} diff --git a/internal/distributed/datacoord/service.go b/internal/distributed/datacoord/service.go index 8d2e29973ceb0..fc8d21ecd094f 100644 --- a/internal/distributed/datacoord/service.go +++ b/internal/distributed/datacoord/service.go @@ -474,3 +474,7 @@ func (s *Server) GetIndexBuildProgress(ctx context.Context, req *indexpb.GetInde func (s *Server) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDataNodeTtMsgsRequest) (*commonpb.Status, error) { return s.dataCoord.ReportDataNodeTtMsgs(ctx, req) } + +func (s *Server) GcControl(ctx context.Context, req *datapb.GcControlRequest) (*commonpb.Status, error) { + return s.dataCoord.GcControl(ctx, req) +} diff --git a/internal/distributed/datacoord/service_test.go b/internal/distributed/datacoord/service_test.go index 15322018083ac..f7ea6bcc095cc 100644 --- a/internal/distributed/datacoord/service_test.go +++ b/internal/distributed/datacoord/service_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/tikv/client-go/v2/txnkv" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/datapb" @@ -323,6 +324,13 @@ func Test_NewServer(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, ret) }) + + t.Run("GcControl", func(t *testing.T) { + mockDataCoord.EXPECT().GcControl(mock.Anything, mock.Anything).Return(&commonpb.Status{}, nil) + ret, err := server.GcControl(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, ret) + }) } func Test_Run(t *testing.T) { diff --git a/internal/http/server.go b/internal/http/server.go index cb5e50521b822..f99a481001dee 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -41,6 +41,14 @@ var ( server *http.Server ) +// Provide alias for native http package +// avoiding import alias when using http package + +type ( + ResponseWriter = http.ResponseWriter + Request = http.Request +) + type Handler struct { Path string HandlerFunc http.HandlerFunc diff --git a/internal/mocks/mock_datacoord.go b/internal/mocks/mock_datacoord.go index ebe92e4976bbd..abd562a0771bf 100644 --- a/internal/mocks/mock_datacoord.go +++ b/internal/mocks/mock_datacoord.go @@ -531,6 +531,61 @@ func (_c *MockDataCoord_GcConfirm_Call) RunAndReturn(run func(context.Context, * return _c } +// GcControl provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GcControl(_a0 context.Context, _a1 *datapb.GcControlRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GcControlRequest) (*commonpb.Status, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GcControlRequest) *commonpb.Status); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, 
*datapb.GcControlRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoord_GcControl_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GcControl' +type MockDataCoord_GcControl_Call struct { + *mock.Call +} + +// GcControl is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *datapb.GcControlRequest +func (_e *MockDataCoord_Expecter) GcControl(_a0 interface{}, _a1 interface{}) *MockDataCoord_GcControl_Call { + return &MockDataCoord_GcControl_Call{Call: _e.mock.On("GcControl", _a0, _a1)} +} + +func (_c *MockDataCoord_GcControl_Call) Run(run func(_a0 context.Context, _a1 *datapb.GcControlRequest)) *MockDataCoord_GcControl_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*datapb.GcControlRequest)) + }) + return _c +} + +func (_c *MockDataCoord_GcControl_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataCoord_GcControl_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoord_GcControl_Call) RunAndReturn(run func(context.Context, *datapb.GcControlRequest) (*commonpb.Status, error)) *MockDataCoord_GcControl_Call { + _c.Call.Return(run) + return _c +} + // GetCollectionStatistics provides a mock function with given fields: _a0, _a1 func (_m *MockDataCoord) GetCollectionStatistics(_a0 context.Context, _a1 *datapb.GetCollectionStatisticsRequest) (*datapb.GetCollectionStatisticsResponse, error) { ret := _m.Called(_a0, _a1) diff --git a/internal/mocks/mock_datacoord_client.go b/internal/mocks/mock_datacoord_client.go index ab7b31ab64eec..8913a61e53f07 100644 --- a/internal/mocks/mock_datacoord_client.go +++ b/internal/mocks/mock_datacoord_client.go @@ -704,6 +704,76 @@ func (_c *MockDataCoordClient_GcConfirm_Call) RunAndReturn(run func(context.Cont return _c } +// GcControl provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GcControl(ctx context.Context, in *datapb.GcControlRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GcControlRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GcControlRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GcControlRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GcControl_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GcControl' +type MockDataCoordClient_GcControl_Call struct { + *mock.Call +} + +// GcControl is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.GcControlRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GcControl(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GcControl_Call { + return &MockDataCoordClient_GcControl_Call{Call: _e.mock.On("GcControl", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GcControl_Call) Run(run func(ctx context.Context, in *datapb.GcControlRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GcControl_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.GcControlRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockDataCoordClient_GcControl_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataCoordClient_GcControl_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GcControl_Call) RunAndReturn(run func(context.Context, *datapb.GcControlRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataCoordClient_GcControl_Call { + _c.Call.Return(run) + return _c +} + // GetCollectionStatistics provides a mock function with given fields: ctx, in, opts func (_m *MockDataCoordClient) GetCollectionStatistics(ctx context.Context, in *datapb.GetCollectionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetCollectionStatisticsResponse, error) { _va := make([]interface{}, len(opts)) diff --git a/internal/proto/data_coord.proto b/internal/proto/data_coord.proto index b70ab1b761744..c896d32f2e6e5 100644 --- a/internal/proto/data_coord.proto +++ b/internal/proto/data_coord.proto @@ -91,6 +91,8 @@ service DataCoord { rpc GcConfirm(GcConfirmRequest) returns (GcConfirmResponse) {} rpc ReportDataNodeTtMsgs(ReportDataNodeTtMsgsRequest) returns (common.Status) {} + + rpc GcControl(GcControlRequest) returns(common.Status){} } service DataNode { @@ -865,3 +867,15 @@ message ImportTaskV2 { string reason = 8; repeated ImportFile files = 9; } + +enum GcCommand { + _ = 0; + Pause = 1; + Resume = 2; +} + +message GcControlRequest { + common.MsgBase base = 1; + GcCommand command = 2; + repeated common.KeyValuePair params = 3; +} \ No newline at end of file diff --git a/internal/proxy/management.go b/internal/proxy/management.go new file mode 100644 index 0000000000000..92aef6c600288 --- /dev/null +++ b/internal/proxy/management.go @@ -0,0 +1,93 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "fmt" + "net/http" + "sync" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + management "github.com/milvus-io/milvus/internal/http" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" +) + +// this file contains proxy management restful API handler + +const ( + mgrRouteGcPause = `/management/datacoord/garbage_collection/pause` + mgrRouteGcResume = `/management/datacoord/garbage_collection/resume` +) + +var mgrRouteRegisterOnce sync.Once + +func RegisterMgrRoute(proxy *Proxy) { + mgrRouteRegisterOnce.Do(func() { + management.Register(&management.Handler{ + Path: mgrRouteGcPause, + HandlerFunc: proxy.PauseDatacoordGC, + }) + management.Register(&management.Handler{ + Path: mgrRouteGcResume, + HandlerFunc: proxy.ResumeDatacoordGC, + }) + }) +} + +func (node *Proxy) PauseDatacoordGC(w http.ResponseWriter, req *http.Request) { + pauseSeconds := req.URL.Query().Get("pause_seconds") + + resp, err := node.dataCoord.GcControl(req.Context(), &datapb.GcControlRequest{ + Base: commonpbutil.NewMsgBase(), + Command: datapb.GcCommand_Pause, + Params: []*commonpb.KeyValuePair{ + {Key: "duration", Value: pauseSeconds}, + }, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(fmt.Sprintf(`{"msg": "failed to pause garbage collection, %s"}`, err.Error()))) + return + } + if resp.GetErrorCode() != commonpb.ErrorCode_Success { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(fmt.Sprintf(`{"msg": "failed to pause garbage collection, %s"}`, resp.GetReason()))) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"msg": "OK"}`)) +} + +func (node *Proxy) ResumeDatacoordGC(w http.ResponseWriter, req *http.Request) { + resp, err := node.dataCoord.GcControl(req.Context(), &datapb.GcControlRequest{ + Base: commonpbutil.NewMsgBase(), + Command: datapb.GcCommand_Resume, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(fmt.Sprintf(`{"msg": "failed to pause garbage collection, %s"}`, err.Error()))) + return + } + if resp.GetErrorCode() != commonpb.ErrorCode_Success { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(fmt.Sprintf(`{"msg": "failed to pause garbage collection, %s"}`, resp.GetReason()))) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"msg": "OK"}`)) +} diff --git a/internal/proxy/management_test.go b/internal/proxy/management_test.go new file mode 100644 index 0000000000000..56654fedaf352 --- /dev/null +++ b/internal/proxy/management_test.go @@ -0,0 +1,163 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package proxy + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/datapb" +) + +type ProxyManagementSuite struct { + suite.Suite + + datacoord *mocks.MockDataCoordClient + proxy *Proxy +} + +func (s *ProxyManagementSuite) SetupTest() { + s.datacoord = mocks.NewMockDataCoordClient(s.T()) + s.proxy = &Proxy{ + dataCoord: s.datacoord, + } +} + +func (s *ProxyManagementSuite) TearDownTest() { + s.datacoord.AssertExpectations(s.T()) +} + +func (s *ProxyManagementSuite) TestPauseDataCoordGC() { + s.Run("normal", func() { + s.SetupTest() + defer s.TearDownTest() + s.datacoord.EXPECT().GcControl(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, req *datapb.GcControlRequest, options ...grpc.CallOption) (*commonpb.Status, error) { + s.Equal(datapb.GcCommand_Pause, req.GetCommand()) + return &commonpb.Status{}, nil + }) + + req, err := http.NewRequest(http.MethodGet, mgrRouteGcPause+"?pause_seconds=60", nil) + s.Require().NoError(err) + + recorder := httptest.NewRecorder() + s.proxy.PauseDatacoordGC(recorder, req) + + s.Equal(http.StatusOK, recorder.Code) + }) + + s.Run("return_error", func() { + s.SetupTest() + defer s.TearDownTest() + s.datacoord.EXPECT().GcControl(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, req *datapb.GcControlRequest, options ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, errors.New("mock") + }) + + req, err := http.NewRequest(http.MethodGet, mgrRouteGcPause+"?pause_seconds=60", nil) + s.Require().NoError(err) + + recorder := httptest.NewRecorder() + s.proxy.PauseDatacoordGC(recorder, req) + + s.Equal(http.StatusInternalServerError, recorder.Code) + }) + + s.Run("return_failure", func() { + s.SetupTest() + defer s.TearDownTest() + s.datacoord.EXPECT().GcControl(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, req *datapb.GcControlRequest, options ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "mocked", + }, nil + }) + + req, err := http.NewRequest(http.MethodGet, mgrRouteGcPause+"?pause_seconds=60", nil) + s.Require().NoError(err) + + recorder := httptest.NewRecorder() + s.proxy.PauseDatacoordGC(recorder, req) + + s.Equal(http.StatusInternalServerError, recorder.Code) + }) +} + +func (s *ProxyManagementSuite) TestResumeDatacoordGC() { + s.Run("normal", func() { + s.SetupTest() + defer s.TearDownTest() + s.datacoord.EXPECT().GcControl(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, req *datapb.GcControlRequest, options ...grpc.CallOption) (*commonpb.Status, error) { + s.Equal(datapb.GcCommand_Resume, req.GetCommand()) + return &commonpb.Status{}, nil + }) + + req, err := http.NewRequest(http.MethodGet, mgrRouteGcResume, nil) + s.Require().NoError(err) + + recorder := httptest.NewRecorder() + s.proxy.ResumeDatacoordGC(recorder, req) + + s.Equal(http.StatusOK, recorder.Code) + }) + + s.Run("return_error", func() { + s.SetupTest() + defer s.TearDownTest() + s.datacoord.EXPECT().GcControl(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, req *datapb.GcControlRequest, options ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, 
errors.New("mock") + }) + + req, err := http.NewRequest(http.MethodGet, mgrRouteGcResume, nil) + s.Require().NoError(err) + + recorder := httptest.NewRecorder() + s.proxy.ResumeDatacoordGC(recorder, req) + + s.Equal(http.StatusInternalServerError, recorder.Code) + }) + + s.Run("return_failure", func() { + s.SetupTest() + defer s.TearDownTest() + s.datacoord.EXPECT().GcControl(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, req *datapb.GcControlRequest, options ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "mocked", + }, nil + }) + + req, err := http.NewRequest(http.MethodGet, mgrRouteGcResume, nil) + s.Require().NoError(err) + + recorder := httptest.NewRecorder() + s.proxy.ResumeDatacoordGC(recorder, req) + + s.Equal(http.StatusInternalServerError, recorder.Code) + }) +} + +func TestProxyManagement(t *testing.T) { + suite.Run(t, new(ProxyManagementSuite)) +} From 6efb7afd3fc0fa446fd25bb11d751caf29426200 Mon Sep 17 00:00:00 2001 From: zhuwenxing Date: Thu, 14 Dec 2023 19:38:45 +0800 Subject: [PATCH 06/14] test: add more request type checker for test (#29210) add more request type checker for test * partition * database * upsert Signed-off-by: zhuwenxing --- tests/python_client/chaos/checker.py | 521 ++++++++++++++---- tests/python_client/chaos/test_chaos.py | 16 +- .../chaos/test_chaos_memory_stress.py | 18 +- .../chaos/test_load_with_checker.py | 12 +- .../testcases/test_all_checker_operation.py | 133 +++++ .../testcases/test_concurrent_operation.py | 2 + .../test_single_request_operation.py | 14 +- ...le_request_operation_for_rolling_update.py | 18 +- ...st_single_request_operation_for_standby.py | 10 +- .../testcases/test_verify_all_collections.py | 10 +- .../loadbalance/test_auto_load_balance.py | 10 +- 11 files changed, 619 insertions(+), 145 deletions(-) create mode 100644 tests/python_client/chaos/testcases/test_all_checker_operation.py diff --git a/tests/python_client/chaos/checker.py b/tests/python_client/chaos/checker.py index 641bf74e8f683..26c63d4f1a4d1 100644 --- a/tests/python_client/chaos/checker.py +++ b/tests/python_client/chaos/checker.py @@ -12,7 +12,9 @@ from prettytable import PrettyTable import functools from time import sleep +from base.database_wrapper import ApiDatabaseWrapper from base.collection_wrapper import ApiCollectionWrapper +from base.partition_wrapper import ApiPartitionWrapper from base.utility_wrapper import ApiUtilityWrapper from common import common_func as cf from common import common_type as ct @@ -195,15 +197,30 @@ def show_result_table(self): class Op(Enum): - create = 'create' + create = 'create' # short name for create collection + create_db = 'create_db' + create_collection = 'create_collection' + create_partition = 'create_partition' insert = 'insert' + upsert = 'upsert' flush = 'flush' index = 'index' + create_index = 'create_index' + drop_index = 'drop_index' + load = 'load' + load_collection = 'load_collection' + load_partition = 'load_partition' + release = 'release' + release_collection = 'release_collection' + release_partition = 'release_partition' search = 'search' query = 'query' delete = 'delete' compact = 'compact' - drop = 'drop' + drop = 'drop' # short name for drop collection + drop_db = 'drop_db' + drop_collection = 'drop_collection' + drop_partition = 'drop_partition' load_balance = 'load_balance' bulk_insert = 'bulk_insert' unknown = 'unknown' @@ -288,7 +305,8 @@ class Checker: b. 
count operations and success rate """ - def __init__(self, collection_name=None, shards_num=2, dim=ct.default_dim, insert_data=True, schema=None): + def __init__(self, collection_name=None, partition_name=None, shards_num=2, dim=ct.default_dim, insert_data=True, + schema=None): self.recovery_time = 0 self._succ = 0 self._fail = 0 @@ -299,11 +317,16 @@ def __init__(self, collection_name=None, shards_num=2, dim=ct.default_dim, inser self.files = [] self.ms = MilvusSys() self.bucket_name = self.ms.index_nodes[0]["infos"]["system_configurations"]["minio_bucket_name"] + self.db_wrap = ApiDatabaseWrapper() self.c_wrap = ApiCollectionWrapper() + self.p_wrap = ApiPartitionWrapper() self.utility_wrap = ApiUtilityWrapper() c_name = collection_name if collection_name is not None else cf.gen_unique_str( 'Checker_') self.c_name = c_name + p_name = partition_name if partition_name is not None else "_default" + self.p_name = p_name + self.p_names = [self.p_name] if partition_name is not None else None schema = cf.gen_default_collection_schema(dim=dim) if schema is None else schema self.schema = schema self.dim = cf.get_dim_by_schema(schema=schema) @@ -314,17 +337,27 @@ def __init__(self, collection_name=None, shards_num=2, dim=ct.default_dim, inser shards_num=shards_num, timeout=timeout, enable_traceback=enable_traceback) + self.p_wrap.init_partition(self.c_name, self.p_name) if insert_data: log.info(f"collection {c_name} created, start to insert data") t0 = time.perf_counter() self.c_wrap.insert( data=cf.get_column_data_by_schema(nb=constants.ENTITIES_FOR_SEARCH, schema=schema, start=0), + partition_name=self.p_name, timeout=timeout, enable_traceback=enable_traceback) log.info(f"insert data for collection {c_name} cost {time.perf_counter() - t0}s") self.initial_entities = self.c_wrap.num_entities # do as a flush + def insert_data(self, nb=constants.ENTITIES_FOR_SEARCH, partition_name=None): + partition_name = self.p_name if partition_name is None else partition_name + self.c_wrap.insert( + data=cf.get_column_data_by_schema(nb=nb, schema=self.schema, start=0), + partition_name=partition_name, + timeout=timeout, + enable_traceback=enable_traceback) + def total(self): return self._succ + self._fail @@ -407,6 +440,140 @@ def do_bulk_insert(self): return task_ids, completed +class CollectionLoadChecker(Checker): + """check collection load operations in a dependent thread""" + + def __init__(self, collection_name=None, shards_num=2, replica_number=1, schema=None, ): + self.replica_number = replica_number + if collection_name is None: + collection_name = cf.gen_unique_str("LoadChecker_") + super().__init__(collection_name=collection_name, shards_num=shards_num, schema=schema) + self.c_wrap.create_index(self.float_vector_field_name, + constants.DEFAULT_INDEX_PARAM, + index_name=cf.gen_unique_str('index_'), + timeout=timeout, + enable_traceback=enable_traceback, + check_task=CheckTasks.check_nothing) + + @trace() + def load_collection(self): + res, result = self.c_wrap.load(replica_number=self.replica_number) + return res, result + + @exception_handler() + def run_task(self): + res, result = self.load_collection() + if result: + self.c_wrap.release() + return res, result + + def keep_running(self): + while self._keep_running: + self.run_task() + sleep(constants.WAIT_PER_OP) + + +class CollectionReleaseChecker(Checker): + """check collection release operations in a dependent thread""" + + def __init__(self, collection_name=None, shards_num=2, replica_number=1, schema=None, ): + self.replica_number = 
replica_number + if collection_name is None: + collection_name = cf.gen_unique_str("LoadChecker_") + super().__init__(collection_name=collection_name, shards_num=shards_num, schema=schema) + self.c_wrap.create_index(self.float_vector_field_name, + constants.DEFAULT_INDEX_PARAM, + index_name=cf.gen_unique_str('index_'), + timeout=timeout, + enable_traceback=enable_traceback, + check_task=CheckTasks.check_nothing) + self.c_wrap.load(replica_number=self.replica_number) + + @trace() + def release_collection(self): + res, result = self.c_wrap.release() + return res, result + + @exception_handler() + def run_task(self): + res, result = self.release_collection() + if result: + self.c_wrap.release() + return res, result + + def keep_running(self): + while self._keep_running: + self.run_task() + sleep(constants.WAIT_PER_OP) + + +class PartitionLoadChecker(Checker): + """check partition load operations in a dependent thread""" + + def __init__(self, collection_name=None, shards_num=2, replica_number=1, schema=None, ): + self.replica_number = replica_number + if collection_name is None: + collection_name = cf.gen_unique_str("LoadChecker_") + super().__init__(collection_name=collection_name, shards_num=shards_num, schema=schema) + self.c_wrap.create_index(self.float_vector_field_name, + constants.DEFAULT_INDEX_PARAM, + index_name=cf.gen_unique_str('index_'), + timeout=timeout, + enable_traceback=enable_traceback, + check_task=CheckTasks.check_nothing) + + @trace() + def load_partition(self): + res, result = self.p_wrap.load(replica_number=self.replica_number) + return res, result + + @exception_handler() + def run_task(self): + res, result = self.load_partition() + if result: + self.p_wrap.release() + return res, result + + def keep_running(self): + while self._keep_running: + self.run_task() + sleep(constants.WAIT_PER_OP) + + +class PartitionReleaseChecker(Checker): + """check partition release operations in a dependent thread""" + + def __init__(self, collection_name=None, shards_num=2, replica_number=1, schema=None, ): + self.replica_number = replica_number + if collection_name is None: + collection_name = cf.gen_unique_str("LoadChecker_") + super().__init__(collection_name=collection_name, shards_num=shards_num, schema=schema) + self.c_wrap.create_index(self.float_vector_field_name, + constants.DEFAULT_INDEX_PARAM, + index_name=cf.gen_unique_str('index_'), + timeout=timeout, + enable_traceback=enable_traceback, + check_task=CheckTasks.check_nothing) + self.p_wrap.load(replica_number=self.replica_number) + + @trace() + def release_partition(self): + res, result = self.p_wrap.release() + return res, result + + @exception_handler() + def run_task(self): + res, result = self.release_partition() + if result: + self.p_wrap.load(replica_number=self.replica_number) + return res, result + + def keep_running(self): + while self._keep_running: + self.run_task() + sleep(constants.WAIT_PER_OP) + + class SearchChecker(Checker): """check search operations in a dependent thread""" @@ -422,6 +589,7 @@ def __init__(self, collection_name=None, shards_num=2, replica_number=1, schema= check_task=CheckTasks.check_nothing) # do load before search self.c_wrap.load(replica_number=replica_number) + self.insert_data() @trace() def search(self): @@ -430,6 +598,7 @@ def search(self): anns_field=self.float_vector_field_name, param=constants.DEFAULT_SEARCH_PARAM, limit=1, + partition_names=self.p_names, timeout=search_timeout, check_task=CheckTasks.check_nothing ) @@ -525,7 +694,7 @@ def keep_running(self): class 
InsertChecker(Checker): - """check flush operations in a dependent thread""" + """check insert operations in a dependent thread""" def __init__(self, collection_name=None, flush=False, shards_num=2, schema=None): if collection_name is None: @@ -540,7 +709,7 @@ def __init__(self, collection_name=None, flush=False, shards_num=2, schema=None) self.file_name = f"/tmp/ci_logs/insert_data_{uuid.uuid4()}.parquet" @trace() - def insert(self): + def insert_entities(self): data = cf.get_column_data_by_schema(nb=constants.DELTA_PER_INS, schema=self.schema) ts_data = [] for i in range(constants.DELTA_PER_INS): @@ -551,6 +720,7 @@ def insert(self): data[0] = ts_data # set timestamp (ms) as int64 log.debug(f"insert data: {ts_data}") res, result = self.c_wrap.insert(data=data, + partition_names=self.p_names, timeout=timeout, enable_traceback=enable_traceback, check_task=CheckTasks.check_nothing) @@ -561,7 +731,7 @@ def insert(self): @exception_handler() def run_task(self): - res, result = self.insert() + res, result = self.insert_entities() return res, result def keep_running(self): @@ -599,8 +769,36 @@ def verify_data_completeness(self): pytest.assume(set(data_in_server) == set(data_in_client)) -class CreateChecker(Checker): - """check create operations in a dependent thread""" +class UpsertChecker(Checker): + """check upsert operations in a dependent thread""" + + def __init__(self, collection_name=None, flush=False, shards_num=2, schema=None): + if collection_name is None: + collection_name = cf.gen_unique_str("InsertChecker_") + super().__init__(collection_name=collection_name, shards_num=shards_num, schema=schema) + + @trace() + def upsert_entities(self): + data = cf.get_column_data_by_schema(nb=constants.DELTA_PER_INS, schema=self.schema) + res, result = self.c_wrap.upsert(data=data, + timeout=timeout, + enable_traceback=enable_traceback, + check_task=CheckTasks.check_nothing) + return res, result + + @exception_handler() + def run_task(self): + res, result = self.upsert_entities() + return res, result + + def keep_running(self): + while self._keep_running: + self.run_task() + sleep(constants.WAIT_PER_OP / 10) + + +class CollectionCreateChecker(Checker): + """check collection create operations in a dependent thread""" def __init__(self, collection_name=None, schema=None): if collection_name is None: @@ -630,8 +828,180 @@ def keep_running(self): sleep(constants.WAIT_PER_OP) -class IndexChecker(Checker): - """check Insert operations in a dependent thread""" +class CollectionDropChecker(Checker): + """check collection drop operations in a dependent thread""" + + def __init__(self, collection_name=None, schema=None): + if collection_name is None: + collection_name = cf.gen_unique_str("DropChecker_") + super().__init__(collection_name=collection_name, schema=schema) + self.collection_pool = [] + self.gen_collection_pool(schema=self.schema) + + def gen_collection_pool(self, pool_size=50, schema=None): + for i in range(pool_size): + collection_name = cf.gen_unique_str("DropChecker_") + res, result = self.c_wrap.init_collection(name=collection_name, schema=schema) + if result: + self.collection_pool.append(collection_name) + + @trace() + def drop_collection(self): + res, result = self.c_wrap.drop() + if result: + self.collection_pool.remove(self.c_wrap.name) + return res, result + + @exception_handler() + def run_task(self): + res, result = self.drop_collection() + return res, result + + def keep_running(self): + while self._keep_running: + res, result = self.run_task() + if result: + try: + if 
len(self.collection_pool) <= 10: + self.gen_collection_pool(schema=self.schema) + except Exception as e: + log.error(f"Failed to generate collection pool: {e}") + try: + c_name = self.collection_pool[0] + self.c_wrap.init_collection(name=c_name) + except Exception as e: + log.error(f"Failed to init new collection: {e}") + sleep(constants.WAIT_PER_OP) + + +class PartitionCreateChecker(Checker): + """check partition create operations in a dependent thread""" + + def __init__(self, collection_name=None, schema=None, partition_name=None): + if collection_name is None: + collection_name = cf.gen_unique_str("PartitionCreateChecker_") + super().__init__(collection_name=collection_name, schema=schema, partition_name=partition_name) + + @trace() + def create_partition(self): + res, result = self.p_wrap.init_partition(collection=self.c_name, + name=cf.gen_unique_str("PartitionCreateChecker_"), + timeout=timeout, + enable_traceback=enable_traceback, + check_task=CheckTasks.check_nothing + ) + return res, result + + @exception_handler() + def run_task(self): + res, result = self.create_partition() + if result: + self.p_wrap.drop(timeout=timeout) + return res, result + + def keep_running(self): + while self._keep_running: + self.run_task() + sleep(constants.WAIT_PER_OP) + + +class PartitionDropChecker(Checker): + """check partition drop operations in a dependent thread""" + + def __init__(self, collection_name=None, schema=None, partition_name=None): + if collection_name is None: + collection_name = cf.gen_unique_str("PartitionDropChecker_") + super().__init__(collection_name=collection_name, schema=schema, partition_name=partition_name) + self.p_wrap.init_partition(collection=self.c_name, + name=cf.gen_unique_str("PartitionDropChecker_"), + timeout=timeout, + enable_traceback=enable_traceback, + check_task=CheckTasks.check_nothing + ) + + @trace() + def drop_partition(self): + res, result = self.p_wrap.drop() + return res, result + + @exception_handler() + def run_task(self): + res, result = self.drop_partition() + if result: + self.p_wrap.init_partition(collection=self.c_name, + name=cf.gen_unique_str("PartitionDropChecker_"), + timeout=timeout, + enable_traceback=enable_traceback, + check_task=CheckTasks.check_nothing + ) + return res, result + + def keep_running(self): + while self._keep_running: + self.run_task() + sleep(constants.WAIT_PER_OP) + + +class DatabaseCreateChecker(Checker): + """check create database operations in a dependent thread""" + + def __init__(self, collection_name=None, schema=None): + if collection_name is None: + collection_name = cf.gen_unique_str("DatabaseChecker_") + super().__init__(collection_name=collection_name, schema=schema) + self.db_name = None + + @trace() + def init_db(self): + db_name = cf.gen_unique_str("db_") + res, result = self.db_wrap.create_database(db_name) + self.db_name = db_name + return res, result + + @exception_handler() + def run_task(self): + res, result = self.init_db() + if result: + self.db_wrap.drop_database(self.db_name) + return res, result + + def keep_running(self): + while self._keep_running: + self.run_task() + sleep(constants.WAIT_PER_OP) + + +class DatabaseDropChecker(Checker): + """check drop database operations in a dependent thread""" + + def __init__(self, collection_name=None, schema=None): + if collection_name is None: + collection_name = cf.gen_unique_str("DatabaseChecker_") + super().__init__(collection_name=collection_name, schema=schema) + self.db_name = cf.gen_unique_str("db_") + self.db_wrap.create_database(self.db_name) 
+ + @trace() + def drop_db(self): + res, result = self.db_wrap.drop_database(self.db_name) + return res, result + + @exception_handler() + def run_task(self): + res, result = self.drop_db() + if result: + self.db_name = cf.gen_unique_str("db_") + self.db_wrap.create_database(self.db_name) + return res, result + + def keep_running(self): + while self._keep_running: + self.run_task() + sleep(constants.WAIT_PER_OP) + + +class IndexCreateChecker(Checker): + """check index create operations in a dependent thread""" def __init__(self, collection_name=None, schema=None): if collection_name is None: @@ -666,52 +1036,54 @@ def keep_running(self): sleep(constants.WAIT_PER_OP * 6) -class QueryChecker(Checker): - """check query operations in a dependent thread""" +class IndexDropChecker(Checker): + """check index drop operations in a dependent thread""" - def __init__(self, collection_name=None, shards_num=2, replica_number=1, schema=None): + def __init__(self, collection_name=None, schema=None): if collection_name is None: - collection_name = cf.gen_unique_str("QueryChecker_") - super().__init__(collection_name=collection_name, shards_num=shards_num, schema=schema) - res, result = self.c_wrap.create_index(self.float_vector_field_name, - constants.DEFAULT_INDEX_PARAM, - index_name=cf.gen_unique_str( - 'index_'), - timeout=timeout, - enable_traceback=enable_traceback, - check_task=CheckTasks.check_nothing) - self.c_wrap.load(replica_number=replica_number) # do load before query - self.term_expr = None + collection_name = cf.gen_unique_str("IndexChecker_") + super().__init__(collection_name=collection_name, schema=schema) + self.index_name = cf.gen_unique_str('index_') + for i in range(5): + self.c_wrap.insert(data=cf.get_column_data_by_schema(nb=constants.ENTITIES_FOR_SEARCH, schema=self.schema), + timeout=timeout, enable_traceback=enable_traceback) + # do as a flush before indexing + log.debug(f"Index ready entities: {self.c_wrap.num_entities}") + self.c_wrap.create_index(self.float_vector_field_name, + constants.DEFAULT_INDEX_PARAM, + index_name=self.index_name, + enable_traceback=enable_traceback, + check_task=CheckTasks.check_nothing) @trace() - def query(self): - res, result = self.c_wrap.query(self.term_expr, timeout=query_timeout, - check_task=CheckTasks.check_nothing) + def drop_index(self): + res, result = self.c_wrap.drop_index(timeout=timeout) return res, result @exception_handler() def run_task(self): - int_values = [] - for _ in range(5): - int_values.append(randint(0, constants.ENTITIES_FOR_SEARCH)) - self.term_expr = f'{self.int64_field_name} in {int_values}' - res, result = self.query() + res, result = self.drop_index() + if result: + self.c_wrap.create_index(self.float_vector_field_name, + constants.DEFAULT_INDEX_PARAM, + index_name=self.index_name, + enable_traceback=enable_traceback, + check_task=CheckTasks.check_nothing) return res, result def keep_running(self): while self._keep_running: self.run_task() - sleep(constants.WAIT_PER_OP / 10) + sleep(constants.WAIT_PER_OP * 6) -class LoadChecker(Checker): - """check load operations in a dependent thread""" +class QueryChecker(Checker): + """check query operations in a dependent thread""" - def __init__(self, collection_name=None, replica_number=1, schema=None): + def __init__(self, collection_name=None, shards_num=2, replica_number=1, schema=None): if collection_name is None: - collection_name = cf.gen_unique_str("LoadChecker_") - super().__init__(collection_name=collection_name, schema=schema) - self.replica_number = replica_number + 
collection_name = cf.gen_unique_str("QueryChecker_") + super().__init__(collection_name=collection_name, shards_num=shards_num, schema=schema) res, result = self.c_wrap.create_index(self.float_vector_field_name, constants.DEFAULT_INDEX_PARAM, index_name=cf.gen_unique_str( @@ -719,17 +1091,23 @@ def __init__(self, collection_name=None, replica_number=1, schema=None): timeout=timeout, enable_traceback=enable_traceback, check_task=CheckTasks.check_nothing) + self.c_wrap.load(replica_number=replica_number) # do load before query + self.insert_data() + self.term_expr = None @trace() - def load(self): - res, result = self.c_wrap.load(replica_number=self.replica_number, timeout=timeout) + def query(self): + res, result = self.c_wrap.query(self.term_expr, timeout=query_timeout, + check_task=CheckTasks.check_nothing) return res, result @exception_handler() def run_task(self): - res, result = self.load() - if result: - self.c_wrap.release() + int_values = [] + for _ in range(5): + int_values.append(randint(0, constants.ENTITIES_FOR_SEARCH)) + self.term_expr = f'{self.int64_field_name} in {int_values}' + res, result = self.query() return res, result def keep_running(self): @@ -753,6 +1131,7 @@ def __init__(self, collection_name=None, schema=None): enable_traceback=enable_traceback, check_task=CheckTasks.check_nothing) self.c_wrap.load() # load before query + self.insert_data() term_expr = f'{self.int64_field_name} > 0' res, _ = self.c_wrap.query(term_expr, output_fields=[ self.int64_field_name]) @@ -760,7 +1139,7 @@ def __init__(self, collection_name=None, schema=None): self.expr = None @trace() - def delete(self): + def delete_entities(self): res, result = self.c_wrap.delete(expr=self.expr, timeout=timeout) return res, result @@ -768,7 +1147,7 @@ def delete(self): def run_task(self): delete_ids = self.ids.pop() self.expr = f'{self.int64_field_name} in {[delete_ids]}' - res, result = self.delete() + res, result = self.delete_entities() return res, result def keep_running(self): @@ -812,54 +1191,8 @@ def keep_running(self): sleep(constants.WAIT_PER_OP / 10) -class DropChecker(Checker): - """check drop operations in a dependent thread""" - - def __init__(self, collection_name=None, schema=None): - if collection_name is None: - collection_name = cf.gen_unique_str("DropChecker_") - super().__init__(collection_name=collection_name, schema=schema) - self.collection_pool = [] - self.gen_collection_pool(schema=self.schema) - - def gen_collection_pool(self, pool_size=50, schema=None): - for i in range(pool_size): - collection_name = cf.gen_unique_str("DropChecker_") - res, result = self.c_wrap.init_collection(name=collection_name, schema=schema) - if result: - self.collection_pool.append(collection_name) - - @trace() - def drop(self): - res, result = self.c_wrap.drop() - if result: - self.collection_pool.remove(self.c_wrap.name) - return res, result - - @exception_handler() - def run_task(self): - res, result = self.drop() - return res, result - - def keep_running(self): - while self._keep_running: - res, result = self.run_task() - if result: - try: - if len(self.collection_pool) <= 10: - self.gen_collection_pool(schema=self.schema) - except Exception as e: - log.error(f"Failed to generate collection pool: {e}") - try: - c_name = self.collection_pool[0] - self.c_wrap.init_collection(name=c_name) - except Exception as e: - log.error(f"Failed to init new collection: {e}") - sleep(constants.WAIT_PER_OP) - - class LoadBalanceChecker(Checker): - """check loadbalance operations in a dependent thread""" + """check 
load balance operations in a dependent thread""" def __init__(self, collection_name=None, schema=None): if collection_name is None: @@ -912,7 +1245,7 @@ def keep_running(self): class BulkInsertChecker(Checker): - """check bulk load operations in a dependent thread""" + """check bulk insert operations in a dependent thread""" def __init__(self, collection_name=None, files=[], use_one_collection=False, dim=ct.default_dim, schema=None, insert_data=False): diff --git a/tests/python_client/chaos/test_chaos.py b/tests/python_client/chaos/test_chaos.py index e4363f59cfc34..867460da36972 100644 --- a/tests/python_client/chaos/test_chaos.py +++ b/tests/python_client/chaos/test_chaos.py @@ -6,8 +6,8 @@ from time import sleep from pymilvus import connections -from chaos.checker import (CreateChecker, InsertChecker, FlushChecker, - SearchChecker, QueryChecker, IndexChecker, DeleteChecker, Op) +from chaos.checker import (CollectionCreateChecker, InsertChecker, FlushChecker, + SearchChecker, QueryChecker, IndexCreateChecker, DeleteChecker, Op) from common.cus_resource_opts import CustomResourceOperations as CusResource from utils.util_log import test_log as log from utils.util_k8s import wait_pods_ready, get_pod_list @@ -20,11 +20,11 @@ def check_cluster_nodes(chaos_config): - # if all pods will be effected, the expect is all fail. + # if all pods will be effected, the expect is all fail. # Even though the replicas is greater than 1, it can not provide HA, so cluster_nodes is set as 1 for this situation. if "all" in chaos_config["metadata"]["name"]: return 1 - + selector = findkeys(chaos_config, "selector") selector = list(selector) log.info(f"chaos target selector: {selector}") @@ -93,7 +93,7 @@ class TestChaos(TestChaosBase): def connection(self, host, port): connections.add_connection(default={"host": host, "port": port}) connections.connect(alias='default') - + if connections.has_connection("default") is False: raise Exception("no connections") self.host = host @@ -102,10 +102,10 @@ def connection(self, host, port): @pytest.fixture(scope="function", autouse=True) def init_health_checkers(self): checkers = { - Op.create: CreateChecker(), + Op.create: CollectionCreateChecker(), Op.insert: InsertChecker(), Op.flush: FlushChecker(), - Op.index: IndexChecker(), + Op.index: IndexCreateChecker(), Op.search: SearchChecker(), Op.query: QueryChecker(), Op.delete: DeleteChecker() @@ -244,4 +244,4 @@ def test_chaos(self, chaos_yaml): # assert all expectations assert_expectations() - log.info("*********************Chaos Test Completed**********************") \ No newline at end of file + log.info("*********************Chaos Test Completed**********************") diff --git a/tests/python_client/chaos/test_chaos_memory_stress.py b/tests/python_client/chaos/test_chaos_memory_stress.py index 0d8f12f55b651..1c2782745784e 100644 --- a/tests/python_client/chaos/test_chaos_memory_stress.py +++ b/tests/python_client/chaos/test_chaos_memory_stress.py @@ -9,7 +9,7 @@ from pymilvus import connections from base.collection_wrapper import ApiCollectionWrapper from base.utility_wrapper import ApiUtilityWrapper -from chaos.checker import Op, CreateChecker, InsertFlushChecker, IndexChecker, SearchChecker, QueryChecker +from chaos.checker import Op, CollectionCreateChecker, InsertFlushChecker, IndexCreateChecker, SearchChecker, QueryChecker from common.cus_resource_opts import CustomResourceOperations as CusResource from common import common_func as cf from common import common_type as ct @@ -74,7 +74,7 @@ def 
test_chaos_memory_stress_querynode(self, connection, chaos_yaml): # wait memory stress sleep(constants.WAIT_PER_OP * 2) - # try to do release, load, query and serach in a duration time loop + # try to do release, load, query and search in a duration time loop try: start = time.time() while time.time() - start < eval(duration): @@ -215,10 +215,10 @@ def test_chaos_memory_stress_etcd(self, chaos_yaml): expected: Verify milvus operation succ rate """ mic_checkers = { - Op.create: CreateChecker(), + Op.create: CollectionCreateChecker(), Op.insert: InsertFlushChecker(), Op.flush: InsertFlushChecker(flush=True), - Op.index: IndexChecker(), + Op.index: IndexCreateChecker(), Op.search: SearchChecker(), Op.query: QueryChecker() } @@ -285,7 +285,7 @@ def prepare_collection(self, host, port): @pytest.mark.skip(reason="https://github.com/milvus-io/milvus/issues/16887") @pytest.mark.tags(CaseLabel.L3) - def test_memory_stress_replicas_befor_load(self, prepare_collection): + def test_memory_stress_replicas_before_load(self, prepare_collection): """ target: test querynode group load with insufficient memory method: 1.Limit querynode memory ? 2Gi @@ -353,7 +353,7 @@ def test_memory_stress_replicas_group_sufficient(self, prepare_collection, mode) def test_memory_stress_replicas_group_insufficient(self, prepare_collection, mode): """ target: test apply stress memory on different number querynodes and the group failed to load, - bacause of the memory is insufficient + because of the memory is insufficient method: 1.Limit querynodes memory 5Gi 2.Create collection and insert 1000,000 entities 3.Apply memory stress on querynodes and it's memory is not enough to load replicas @@ -529,7 +529,7 @@ def test_memory_stress_replicas_group_load_balance(self, prepare_collection): chaos_res.delete(metadata_name=chaos_config.get('metadata', None).get('name', None)) - # Verfiy auto load loadbalance + # Verify auto load loadbalance seg_info_after, _ = utility_w.get_query_segment_info(collection_w.name) seg_distribution_after = cf.get_segment_distribution(seg_info_after) segments_num_after = len(seg_distribution_after[chaos_querynode_id]["sealed"]) @@ -549,7 +549,7 @@ def test_memory_stress_replicas_cross_group_load_balance(self, prepare_collectio method: 1.Limit all querynodes memory 6Gi 2.Create and insert 1000,000 entities 3.Load collection with two replicas - 4.Apply memory stress on one grooup 80% + 4.Apply memory stress on one group 80% expected: Verify that load balancing across groups is not occurring """ collection_w = prepare_collection @@ -586,7 +586,7 @@ def test_memory_stress_replicas_cross_group_load_balance(self, prepare_collectio chaos_res.delete(metadata_name=chaos_config.get('metadata', None).get('name', None)) - # Verfiy auto load loadbalance + # Verify auto load loadbalance seg_info_after, _ = utility_w.get_query_segment_info(collection_w.name) seg_distribution_before = cf.get_segment_distribution(seg_info_before) seg_distribution_after = cf.get_segment_distribution(seg_info_after) diff --git a/tests/python_client/chaos/test_load_with_checker.py b/tests/python_client/chaos/test_load_with_checker.py index 419a38b42be58..724c467ba5181 100644 --- a/tests/python_client/chaos/test_load_with_checker.py +++ b/tests/python_client/chaos/test_load_with_checker.py @@ -4,15 +4,15 @@ from time import sleep from minio import Minio from pymilvus import connections -from chaos.checker import (CreateChecker, +from chaos.checker import (CollectionCreateChecker, InsertChecker, FlushChecker, SearchChecker, QueryChecker, - 
IndexChecker, + IndexCreateChecker, DeleteChecker, CompactChecker, - DropChecker, + CollectionDropChecker, LoadBalanceChecker, BulkInsertChecker, Op) @@ -56,15 +56,15 @@ def connection(self, host, port): def init_health_checkers(self): c_name = cf.gen_unique_str("Checker_") checkers = { - # Op.create: CreateChecker(collection_name=c_name), + # Op.create: CollectionCreateChecker(collection_name=c_name), # Op.insert: InsertChecker(collection_name=c_name), # Op.flush: FlushChecker(collection_name=c_name), # Op.query: QueryChecker(collection_name=c_name), # Op.search: SearchChecker(collection_name=c_name), # Op.delete: DeleteChecker(collection_name=c_name), # Op.compact: CompactChecker(collection_name=c_name), - # Op.index: IndexChecker(), - # Op.drop: DropChecker(), + # Op.index: IndexCreateChecker(), + # Op.drop: CollectionDropChecker(), # Op.bulk_insert: BulkInsertChecker(), Op.load_balance: LoadBalanceChecker() } diff --git a/tests/python_client/chaos/testcases/test_all_checker_operation.py b/tests/python_client/chaos/testcases/test_all_checker_operation.py new file mode 100644 index 0000000000000..1c00febbe5ac4 --- /dev/null +++ b/tests/python_client/chaos/testcases/test_all_checker_operation.py @@ -0,0 +1,133 @@ +import time + +import pytest +from time import sleep +from pymilvus import connections +from chaos.checker import ( + DatabaseCreateChecker, + DatabaseDropChecker, + CollectionCreateChecker, + CollectionDropChecker, + PartitionCreateChecker, + PartitionDropChecker, + CollectionLoadChecker, + CollectionReleaseChecker, + PartitionLoadChecker, + PartitionReleaseChecker, + IndexCreateChecker, + IndexDropChecker, + InsertChecker, + UpsertChecker, + DeleteChecker, + FlushChecker, + SearchChecker, + QueryChecker, + Op, + EventRecords, + ResultAnalyzer +) +from utils.util_log import test_log as log +from utils.util_k8s import wait_pods_ready, get_milvus_instance_name +from chaos import chaos_commons as cc +from common.common_type import CaseLabel +from common.milvus_sys import MilvusSys +from chaos.chaos_commons import assert_statistic +from chaos import constants +from delayed_assert import assert_expectations + + +class TestBase: + expect_create = constants.SUCC + expect_insert = constants.SUCC + expect_flush = constants.SUCC + expect_index = constants.SUCC + expect_search = constants.SUCC + expect_query = constants.SUCC + host = '127.0.0.1' + port = 19530 + _chaos_config = None + health_checkers = {} + + +class TestOperations(TestBase): + + @pytest.fixture(scope="function", autouse=True) + def connection(self, host, port, user, password, milvus_ns): + if user and password: + # log.info(f"connect to {host}:{port} with user {user} and password {password}") + connections.connect('default', host=host, port=port, user=user, password=password, secure=True) + else: + connections.connect('default', host=host, port=port) + if connections.has_connection("default") is False: + raise Exception("no connections") + log.info("connect to milvus successfully") + self.host = host + self.port = port + self.user = user + self.password = password + self.milvus_sys = MilvusSys(alias='default') + self.milvus_ns = milvus_ns + self.release_name = get_milvus_instance_name(self.milvus_ns, milvus_sys=self.milvus_sys) + + def init_health_checkers(self, collection_name=None): + c_name = collection_name + checkers = { + Op.create_db: DatabaseCreateChecker(), + Op.create_collection: CollectionCreateChecker(collection_name=c_name), + Op.create_partition: PartitionCreateChecker(collection_name=c_name), + Op.drop_db: 
DatabaseDropChecker(), + Op.drop_collection: CollectionDropChecker(collection_name=c_name), + Op.drop_partition: PartitionDropChecker(collection_name=c_name), + Op.load_collection: CollectionLoadChecker(collection_name=c_name), + Op.load_partition: PartitionLoadChecker(collection_name=c_name), + Op.release_collection: CollectionReleaseChecker(collection_name=c_name), + Op.release_partition: PartitionReleaseChecker(collection_name=c_name), + Op.insert: InsertChecker(collection_name=c_name), + Op.upsert: UpsertChecker(collection_name=c_name), + Op.flush: FlushChecker(collection_name=c_name), + Op.create_index: IndexCreateChecker(collection_name=c_name), + Op.drop_index: IndexDropChecker(collection_name=c_name), + Op.search: SearchChecker(collection_name=c_name), + Op.query: QueryChecker(collection_name=c_name), + Op.delete: DeleteChecker(collection_name=c_name), + Op.drop: CollectionDropChecker(collection_name=c_name) + } + self.health_checkers = checkers + + @pytest.mark.tags(CaseLabel.L3) + def test_operations(self, request_duration, is_check): + # start the monitor threads to check the milvus ops + log.info("*********************Test Start**********************") + log.info(connections.get_connection_addr('default')) + event_records = EventRecords() + c_name = None + event_records.insert("init_health_checkers", "start") + self.init_health_checkers(collection_name=c_name) + event_records.insert("init_health_checkers", "finished") + tasks = cc.start_monitor_threads(self.health_checkers) + log.info("*********************Load Start**********************") + # wait request_duration + request_duration = request_duration.replace("h", "*3600+").replace("m", "*60+").replace("s", "") + if request_duration[-1] == "+": + request_duration = request_duration[:-1] + request_duration = eval(request_duration) + for i in range(10): + sleep(request_duration // 10) + # add an event so that the chaos can start to apply + if i == 3: + event_records.insert("init_chaos", "ready") + for k, v in self.health_checkers.items(): + v.check_result() + if is_check: + assert_statistic(self.health_checkers, succ_rate_threshold=0.98) + assert_expectations() + # wait all pod ready + wait_pods_ready(self.milvus_ns, f"app.kubernetes.io/instance={self.release_name}") + time.sleep(60) + cc.check_thread_status(tasks) + for k, v in self.health_checkers.items(): + v.pause() + ra = ResultAnalyzer() + ra.get_stage_success_rate() + ra.show_result_table() + log.info("*********************Chaos Test Completed**********************") diff --git a/tests/python_client/chaos/testcases/test_concurrent_operation.py b/tests/python_client/chaos/testcases/test_concurrent_operation.py index e72ddcfbdc3a9..c582c5a803725 100644 --- a/tests/python_client/chaos/testcases/test_concurrent_operation.py +++ b/tests/python_client/chaos/testcases/test_concurrent_operation.py @@ -4,6 +4,7 @@ from time import sleep from pymilvus import connections from chaos.checker import (InsertChecker, + UpsertChecker, FlushChecker, SearchChecker, QueryChecker, @@ -70,6 +71,7 @@ def init_health_checkers(self, collection_name=None): c_name = collection_name checkers = { Op.insert: InsertChecker(collection_name=c_name), + Op.upsert: UpsertChecker(collection_name=c_name), Op.flush: FlushChecker(collection_name=c_name), Op.search: SearchChecker(collection_name=c_name), Op.query: QueryChecker(collection_name=c_name), diff --git a/tests/python_client/chaos/testcases/test_single_request_operation.py b/tests/python_client/chaos/testcases/test_single_request_operation.py index 
b7fa746ebbb3d..e5f38afcb3a3a 100644 --- a/tests/python_client/chaos/testcases/test_single_request_operation.py +++ b/tests/python_client/chaos/testcases/test_single_request_operation.py @@ -3,14 +3,15 @@ import pytest from time import sleep from pymilvus import connections -from chaos.checker import (CreateChecker, +from chaos.checker import (CollectionCreateChecker, InsertChecker, + UpsertChecker, FlushChecker, SearchChecker, QueryChecker, - IndexChecker, + IndexCreateChecker, DeleteChecker, - DropChecker, + CollectionDropChecker, Op, EventRecords, ResultAnalyzer @@ -61,14 +62,15 @@ def connection(self, host, port, user, password, milvus_ns): def init_health_checkers(self, collection_name=None): c_name = collection_name checkers = { - Op.create: CreateChecker(collection_name=c_name), + Op.create: CollectionCreateChecker(collection_name=c_name), Op.insert: InsertChecker(collection_name=c_name), + Op.upsert: UpsertChecker(collection_name=c_name), Op.flush: FlushChecker(collection_name=c_name), - Op.index: IndexChecker(collection_name=c_name), + Op.index: IndexCreateChecker(collection_name=c_name), Op.search: SearchChecker(collection_name=c_name), Op.query: QueryChecker(collection_name=c_name), Op.delete: DeleteChecker(collection_name=c_name), - Op.drop: DropChecker(collection_name=c_name) + Op.drop: CollectionDropChecker(collection_name=c_name) } self.health_checkers = checkers diff --git a/tests/python_client/chaos/testcases/test_single_request_operation_for_rolling_update.py b/tests/python_client/chaos/testcases/test_single_request_operation_for_rolling_update.py index 4b6ec7640c2da..1e241ee3ca34a 100644 --- a/tests/python_client/chaos/testcases/test_single_request_operation_for_rolling_update.py +++ b/tests/python_client/chaos/testcases/test_single_request_operation_for_rolling_update.py @@ -6,14 +6,15 @@ from yaml import full_load from pymilvus import connections, utility -from chaos.checker import (CreateChecker, +from chaos.checker import (CollectionCreateChecker, InsertChecker, + UpsertChecker, FlushChecker, SearchChecker, QueryChecker, - IndexChecker, + IndexCreateChecker, DeleteChecker, - DropChecker, + CollectionDropChecker, Op) from utils.util_k8s import wait_pods_ready from utils.util_log import test_log as log @@ -61,14 +62,15 @@ def init_health_checkers(self, collection_name=None): schema = cf.gen_default_collection_schema(auto_id=False) checkers = { - Op.create: CreateChecker(collection_name=None, schema=schema), + Op.create: CollectionCreateChecker(collection_name=None, schema=schema), Op.insert: InsertChecker(collection_name=c_name, schema=schema), + Op.upsert: UpsertChecker(collection_name=c_name, schema=schema), Op.flush: FlushChecker(collection_name=c_name, schema=schema), - Op.index: IndexChecker(collection_name=None, schema=schema), + Op.index: IndexCreateChecker(collection_name=None, schema=schema), Op.search: SearchChecker(collection_name=c_name, schema=schema), Op.query: QueryChecker(collection_name=c_name, schema=schema), Op.delete: DeleteChecker(collection_name=c_name, schema=schema), - Op.drop: DropChecker(collection_name=None, schema=schema) + Op.drop: CollectionDropChecker(collection_name=None, schema=schema) } self.health_checkers = checkers @@ -132,9 +134,9 @@ def test_operations(self, request_duration, is_check): v.pause() for k, v in self.health_checkers.items(): v.check_result() - for k, v in self.health_checkers.items(): + for k, v in self.health_checkers.items(): log.info(f"{k} failed request: {v.fail_records}") - for k, v in 
self.health_checkers.items(): + for k, v in self.health_checkers.items(): log.info(f"{k} rto: {v.get_rto()}") if is_check: assert_statistic(self.health_checkers, succ_rate_threshold=0.98) diff --git a/tests/python_client/chaos/testcases/test_single_request_operation_for_standby.py b/tests/python_client/chaos/testcases/test_single_request_operation_for_standby.py index ea82d14b81378..c2b6e8313c835 100644 --- a/tests/python_client/chaos/testcases/test_single_request_operation_for_standby.py +++ b/tests/python_client/chaos/testcases/test_single_request_operation_for_standby.py @@ -2,12 +2,12 @@ import threading from time import sleep from pymilvus import connections -from chaos.checker import (CreateChecker, +from chaos.checker import (CollectionCreateChecker, InsertChecker, FlushChecker, SearchChecker, QueryChecker, - IndexChecker, + IndexCreateChecker, DeleteChecker, Op) from utils.util_log import test_log as log @@ -60,10 +60,10 @@ def connection(self, host, port, user, password, milvus_ns): def init_health_checkers(self, collection_name=None): c_name = collection_name checkers = { - Op.create: CreateChecker(collection_name=c_name), + Op.create: CollectionCreateChecker(collection_name=c_name), Op.insert: InsertChecker(collection_name=c_name), Op.flush: FlushChecker(collection_name=c_name), - Op.index: IndexChecker(collection_name=c_name), + Op.index: IndexCreateChecker(collection_name=c_name), Op.search: SearchChecker(collection_name=c_name), Op.query: QueryChecker(collection_name=c_name), Op.delete: DeleteChecker(collection_name=c_name), @@ -102,4 +102,4 @@ def test_operations(self, request_duration, target_component, is_check): rto = v.get_rto() pytest.assume(rto < 30, f"{k} rto expect 30s but get {rto}s") # rto should be less than 30s - log.info("*********************Chaos Test Completed**********************") \ No newline at end of file + log.info("*********************Chaos Test Completed**********************") diff --git a/tests/python_client/chaos/testcases/test_verify_all_collections.py b/tests/python_client/chaos/testcases/test_verify_all_collections.py index 3d1315f2f39a0..38d420ec95d3f 100644 --- a/tests/python_client/chaos/testcases/test_verify_all_collections.py +++ b/tests/python_client/chaos/testcases/test_verify_all_collections.py @@ -3,10 +3,11 @@ from collections import defaultdict from pymilvus import connections from chaos.checker import (InsertChecker, - FlushChecker, + UpsertChecker, + FlushChecker, SearchChecker, QueryChecker, - IndexChecker, + IndexCreateChecker, DeleteChecker, Op) from utils.util_log import test_log as log @@ -67,14 +68,15 @@ def connection(self, host, port, user, password): self.host = host self.port = port self.user = user - self.password = password + self.password = password def init_health_checkers(self, collection_name=None): c_name = collection_name checkers = { Op.insert: InsertChecker(collection_name=c_name), + Op.upsert: UpsertChecker(collection_name=c_name), Op.flush: FlushChecker(collection_name=c_name), - Op.index: IndexChecker(collection_name=c_name), + Op.index: IndexCreateChecker(collection_name=c_name), Op.search: SearchChecker(collection_name=c_name), Op.query: QueryChecker(collection_name=c_name), Op.delete: DeleteChecker(collection_name=c_name), diff --git a/tests/python_client/loadbalance/test_auto_load_balance.py b/tests/python_client/loadbalance/test_auto_load_balance.py index 739d9950680eb..cb6b1b5ec1a13 100644 --- a/tests/python_client/loadbalance/test_auto_load_balance.py +++ 
b/tests/python_client/loadbalance/test_auto_load_balance.py
@@ -1,7 +1,7 @@
 from time import sleep
 from pymilvus import connections, list_collections, utility
-from chaos.checker import (CreateChecker, InsertFlushChecker,
-                           SearchChecker, QueryChecker, IndexChecker, Op)
+from chaos.checker import (CollectionCreateChecker, InsertFlushChecker,
+                           SearchChecker, QueryChecker, IndexCreateChecker, Op)
 from common.milvus_sys import MilvusSys
 from utils.util_log import test_log as log
 from chaos import chaos_commons as cc
@@ -74,15 +74,15 @@ def test_auto_load_balance(self):
         conn = connections.connect("default", host=host, port=port)
         assert conn is not None
         self.health_checkers = {
-            Op.create: CreateChecker(),
+            Op.create: CollectionCreateChecker(),
             Op.insert: InsertFlushChecker(),
             Op.flush: InsertFlushChecker(flush=True),
-            Op.index: IndexChecker(),
+            Op.index: IndexCreateChecker(),
             Op.search: SearchChecker(),
             Op.query: QueryChecker()
         }
         cc.start_monitor_threads(self.health_checkers)
-        # wait 
+        # wait
         sleep(constants.WAIT_PER_OP * 10)
         all_collections = list_collections()
         for c in all_collections:
From 25a4525297413681f77b4ed351bc10800cf94c6d Mon Sep 17 00:00:00 2001
From: congqixia
Date: Thu, 14 Dec 2023 20:46:41 +0800
Subject: [PATCH 07/14] enhance: Change sync manager parallel config item (#29216)

Since the sync manager is global in datanode now, the old
`maxParallelSyncTaskNum` does not fit into the current implementation
anymore. This PR adds a new param item for sync manager parallel control
and enlarges the default value.

Signed-off-by: Congqi Xia
---
 internal/datanode/data_node.go         |  2 +-
 pkg/util/paramtable/component_param.go | 12 +++++++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go
index ffa7d2d74e11b..c18814900ed91 100644
--- a/internal/datanode/data_node.go
+++ b/internal/datanode/data_node.go
@@ -273,7 +273,7 @@ func (node *DataNode) Init() error {
         }
         node.chunkManager = chunkManager
 
-        syncMgr, err := syncmgr.NewSyncManager(paramtable.Get().DataNodeCfg.MaxParallelSyncTaskNum.GetAsInt(),
+        syncMgr, err := syncmgr.NewSyncManager(paramtable.Get().DataNodeCfg.MaxParallelSyncMgrTasks.GetAsInt(),
             node.chunkManager, node.allocator)
         if err != nil {
             initError = err
diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go
index d3d5fa2acda2f..155b4dedc1da1 100644
--- a/pkg/util/paramtable/component_param.go
+++ b/pkg/util/paramtable/component_param.go
@@ -2589,6 +2589,7 @@ type dataNodeConfig struct {
     FlowGraphMaxQueueLength ParamItem `refreshable:"false"`
     FlowGraphMaxParallelism ParamItem `refreshable:"false"`
     MaxParallelSyncTaskNum  ParamItem `refreshable:"false"`
+    MaxParallelSyncMgrTasks ParamItem `refreshable:"false"`
 
     // skip mode
     FlowGraphSkipModeEnable ParamItem `refreshable:"true"`
@@ -2686,11 +2687,20 @@ func (p *dataNodeConfig) init(base *BaseTable) {
         Key:          "dataNode.dataSync.maxParallelSyncTaskNum",
         Version:      "2.3.0",
         DefaultValue: "6",
-        Doc:          "Maximum number of sync tasks executed in parallel in each flush manager",
+        Doc:          "deprecated, legacy flush manager max concurrency number",
         Export:       true,
     }
     p.MaxParallelSyncTaskNum.Init(base.mgr)
 
+    p.MaxParallelSyncMgrTasks = ParamItem{
+        Key:          "dataNode.dataSync.maxParallelSyncMgrTasks",
+        Version:      "2.3.4",
+        DefaultValue: "64",
+        Doc:          "The max concurrent sync task number of datanode sync mgr globally",
+        Export:       true,
+    }
+    p.MaxParallelSyncMgrTasks.Init(base.mgr)
+
     p.FlushInsertBufferSize = ParamItem{
         Key:          "dataNode.segment.insertBufSize",
Version: "2.0.0", From 4731c1b0d529722e2402beb3f8c788199df2afe7 Mon Sep 17 00:00:00 2001 From: congqixia Date: Fri, 15 Dec 2023 09:58:43 +0800 Subject: [PATCH 08/14] enhance: make SyncManager pool size refreshable (#29224) See also #29223 This PR make `conc.Pool` resizable by adding `Resize` method for it. Also make newly added datanode `MaxParallelSyncMgrTasks` config refreshable --------- Signed-off-by: Congqi.Xia --- internal/datanode/data_node.go | 3 +- internal/datanode/mock_test.go | 2 +- .../datanode/syncmgr/key_lock_dispatcher.go | 5 +- internal/datanode/syncmgr/sync_manager.go | 54 +++++++++++++--- .../datanode/syncmgr/sync_manager_test.go | 63 +++++++++++++++++-- pkg/util/conc/pool.go | 13 ++++ pkg/util/conc/pool_test.go | 21 +++++++ pkg/util/paramtable/component_param.go | 2 +- 8 files changed, 143 insertions(+), 20 deletions(-) diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go index c18814900ed91..5b89d835f0758 100644 --- a/internal/datanode/data_node.go +++ b/internal/datanode/data_node.go @@ -273,8 +273,7 @@ func (node *DataNode) Init() error { } node.chunkManager = chunkManager - syncMgr, err := syncmgr.NewSyncManager(paramtable.Get().DataNodeCfg.MaxParallelSyncMgrTasks.GetAsInt(), - node.chunkManager, node.allocator) + syncMgr, err := syncmgr.NewSyncManager(node.chunkManager, node.allocator) if err != nil { initError = err log.Error("failed to create sync manager", zap.Error(err)) diff --git a/internal/datanode/mock_test.go b/internal/datanode/mock_test.go index 66aec069e34b0..211310f599767 100644 --- a/internal/datanode/mock_test.go +++ b/internal/datanode/mock_test.go @@ -92,7 +92,7 @@ func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNod node.broker = broker node.timeTickSender = newTimeTickSender(node.broker, 0) - syncMgr, _ := syncmgr.NewSyncManager(10, node.chunkManager, node.allocator) + syncMgr, _ := syncmgr.NewSyncManager(node.chunkManager, node.allocator) node.syncMgr = syncMgr node.writeBufferManager = writebuffer.NewManager(node.syncMgr) diff --git a/internal/datanode/syncmgr/key_lock_dispatcher.go b/internal/datanode/syncmgr/key_lock_dispatcher.go index 493c53c57c6a5..6a51a0f7fbdcf 100644 --- a/internal/datanode/syncmgr/key_lock_dispatcher.go +++ b/internal/datanode/syncmgr/key_lock_dispatcher.go @@ -20,10 +20,11 @@ type keyLockDispatcher[K comparable] struct { } func newKeyLockDispatcher[K comparable](maxParallel int) *keyLockDispatcher[K] { - return &keyLockDispatcher[K]{ - workerPool: conc.NewPool[error](maxParallel, conc.WithPreAlloc(true)), + dispatcher := &keyLockDispatcher[K]{ + workerPool: conc.NewPool[error](maxParallel, conc.WithPreAlloc(false)), keyLock: lock.NewKeyLock[K](), } + return dispatcher } func (d *keyLockDispatcher[K]) Submit(key K, t Task, callbacks ...func(error)) *conc.Future[error] { diff --git a/internal/datanode/syncmgr/sync_manager.go b/internal/datanode/syncmgr/sync_manager.go index 9358a1f6383b0..b78c8cda2af3d 100644 --- a/internal/datanode/syncmgr/sync_manager.go +++ b/internal/datanode/syncmgr/sync_manager.go @@ -5,13 +5,18 @@ import ( "fmt" "strconv" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/datanode/metacache" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/config" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/conc" 
"github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -57,19 +62,48 @@ type syncManager struct { tasks *typeutil.ConcurrentMap[string, Task] } -func NewSyncManager(parallelTask int, chunkManager storage.ChunkManager, allocator allocator.Interface) (SyncManager, error) { - if parallelTask < 1 { - return nil, merr.WrapErrParameterInvalid("positive parallel task number", strconv.FormatInt(int64(parallelTask), 10)) +func NewSyncManager(chunkManager storage.ChunkManager, allocator allocator.Interface) (SyncManager, error) { + params := paramtable.Get() + initPoolSize := params.DataNodeCfg.MaxParallelSyncMgrTasks.GetAsInt() + if initPoolSize < 1 { + return nil, merr.WrapErrParameterInvalid("positive parallel task number", strconv.FormatInt(int64(initPoolSize), 10)) } - return &syncManager{ - keyLockDispatcher: newKeyLockDispatcher[int64](parallelTask), + dispatcher := newKeyLockDispatcher[int64](initPoolSize) + log.Info("sync manager initialized", zap.Int("initPoolSize", initPoolSize)) + + syncMgr := &syncManager{ + keyLockDispatcher: dispatcher, chunkManager: chunkManager, allocator: allocator, tasks: typeutil.NewConcurrentMap[string, Task](), - }, nil + } + // setup config update watcher + params.Watch(params.DataNodeCfg.MaxParallelSyncMgrTasks.Key, config.NewHandler("datanode.syncmgr.poolsize", syncMgr.resizeHandler)) + + return syncMgr, nil +} + +func (mgr *syncManager) resizeHandler(evt *config.Event) { + if evt.HasUpdated { + log := log.Ctx(context.Background()).With( + zap.String("key", evt.Key), + zap.String("value", evt.Value), + ) + size, err := strconv.ParseInt(evt.Value, 10, 64) + if err != nil { + log.Warn("failed to parse new datanode syncmgr pool size", zap.Error(err)) + return + } + err = mgr.keyLockDispatcher.workerPool.Resize(int(size)) + if err != nil { + log.Warn("failed to resize datanode syncmgr pool size", zap.String("key", evt.Key), zap.String("value", evt.Value), zap.Error(err)) + return + } + log.Info("sync mgr pool size updated", zap.Int64("newSize", size)) + } } -func (mgr syncManager) SyncData(ctx context.Context, task Task) *conc.Future[error] { +func (mgr *syncManager) SyncData(ctx context.Context, task Task) *conc.Future[error] { switch t := task.(type) { case *SyncTask: t.WithAllocator(mgr.allocator).WithChunkManager(mgr.chunkManager) @@ -88,7 +122,7 @@ func (mgr syncManager) SyncData(ctx context.Context, task Task) *conc.Future[err }) } -func (mgr syncManager) GetEarliestPosition(channel string) (int64, *msgpb.MsgPosition) { +func (mgr *syncManager) GetEarliestPosition(channel string) (int64, *msgpb.MsgPosition) { var cp *msgpb.MsgPosition var segmentID int64 mgr.tasks.Range(func(_ string, task Task) bool { @@ -106,10 +140,10 @@ func (mgr syncManager) GetEarliestPosition(channel string) (int64, *msgpb.MsgPos return segmentID, cp } -func (mgr syncManager) Block(segmentID int64) { +func (mgr *syncManager) Block(segmentID int64) { mgr.keyLock.Lock(segmentID) } -func (mgr syncManager) Unblock(segmentID int64) { +func (mgr *syncManager) Unblock(segmentID int64) { mgr.keyLock.Unlock(segmentID) } diff --git a/internal/datanode/syncmgr/sync_manager_test.go b/internal/datanode/syncmgr/sync_manager_test.go index 6bfdb5fc4633f..953f5c4c49348 100644 --- a/internal/datanode/syncmgr/sync_manager_test.go +++ b/internal/datanode/syncmgr/sync_manager_test.go @@ -3,6 +3,7 @@ package syncmgr import ( "context" "math/rand" + "strconv" "testing" "time" @@ -21,6 +22,7 @@ import ( 
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/config" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/tsoutil" ) @@ -41,7 +43,7 @@ type SyncManagerSuite struct { } func (s *SyncManagerSuite) SetupSuite() { - paramtable.Get().Init(paramtable.NewBaseTable()) + paramtable.Get().Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) s.collectionID = 100 s.partitionID = 101 @@ -155,7 +157,7 @@ func (s *SyncManagerSuite) TestSubmit() { s.metacache.EXPECT().GetSegmentsBy(mock.Anything).Return([]*metacache.SegmentInfo{seg}) s.metacache.EXPECT().UpdateSegments(mock.Anything, mock.Anything).Return() - manager, err := NewSyncManager(10, s.chunkManager, s.allocator) + manager, err := NewSyncManager(s.chunkManager, s.allocator) s.NoError(err) task := s.getSuiteSyncTask() task.WithMetaWriter(BrokerMetaWriter(s.broker)) @@ -187,7 +189,7 @@ func (s *SyncManagerSuite) TestCompacted() { s.metacache.EXPECT().GetSegmentsBy(mock.Anything).Return([]*metacache.SegmentInfo{seg}) s.metacache.EXPECT().UpdateSegments(mock.Anything, mock.Anything).Return() - manager, err := NewSyncManager(10, s.chunkManager, s.allocator) + manager, err := NewSyncManager(s.chunkManager, s.allocator) s.NoError(err) task := s.getSuiteSyncTask() task.WithMetaWriter(BrokerMetaWriter(s.broker)) @@ -225,7 +227,7 @@ func (s *SyncManagerSuite) TestBlock() { } }) - manager, err := NewSyncManager(10, s.chunkManager, s.allocator) + manager, err := NewSyncManager(s.chunkManager, s.allocator) s.NoError(err) // block @@ -253,6 +255,59 @@ func (s *SyncManagerSuite) TestBlock() { <-sig } +func (s *SyncManagerSuite) TestResizePool() { + manager, err := NewSyncManager(s.chunkManager, s.allocator) + s.NoError(err) + + syncMgr, ok := manager.(*syncManager) + s.Require().True(ok) + + cap := syncMgr.keyLockDispatcher.workerPool.Cap() + s.NotZero(cap) + + params := paramtable.Get() + configKey := params.DataNodeCfg.MaxParallelSyncMgrTasks.Key + + syncMgr.resizeHandler(&config.Event{ + Key: configKey, + Value: "abc", + HasUpdated: true, + }) + + s.Equal(cap, syncMgr.keyLockDispatcher.workerPool.Cap()) + + syncMgr.resizeHandler(&config.Event{ + Key: configKey, + Value: "-1", + HasUpdated: true, + }) + s.Equal(cap, syncMgr.keyLockDispatcher.workerPool.Cap()) + + syncMgr.resizeHandler(&config.Event{ + Key: configKey, + Value: strconv.FormatInt(int64(cap*2), 10), + HasUpdated: true, + }) + s.Equal(cap*2, syncMgr.keyLockDispatcher.workerPool.Cap()) +} + +func (s *SyncManagerSuite) TestNewSyncManager() { + manager, err := NewSyncManager(s.chunkManager, s.allocator) + s.NoError(err) + + _, ok := manager.(*syncManager) + s.Require().True(ok) + + params := paramtable.Get() + configKey := params.DataNodeCfg.MaxParallelSyncMgrTasks.Key + defer params.Reset(configKey) + + params.Save(configKey, "0") + + _, err = NewSyncManager(s.chunkManager, s.allocator) + s.Error(err) +} + func TestSyncManager(t *testing.T) { suite.Run(t, new(SyncManagerSuite)) } diff --git a/pkg/util/conc/pool.go b/pkg/util/conc/pool.go index d5b3e286e7e65..8c6c1fb25cfd6 100644 --- a/pkg/util/conc/pool.go +++ b/pkg/util/conc/pool.go @@ -18,12 +18,14 @@ package conc import ( "fmt" + "strconv" "sync" ants "github.com/panjf2000/ants/v2" "github.com/milvus-io/milvus/pkg/util/generic" "github.com/milvus-io/milvus/pkg/util/hardware" + "github.com/milvus-io/milvus/pkg/util/merr" ) // A goroutine pool @@ -110,6 +112,17 @@ func (pool 
*Pool[T]) Release() { pool.inner.Release() } +func (pool *Pool[T]) Resize(size int) error { + if pool.opt.preAlloc { + return merr.WrapErrServiceInternal("cannot resize pre-alloc pool") + } + if size <= 0 { + return merr.WrapErrParameterInvalid("positive size", strconv.FormatInt(int64(size), 10)) + } + pool.inner.Tune(size) + return nil +} + // WarmupPool do warm up logic for each goroutine in pool func WarmupPool[T any](pool *Pool[T], warmup func()) { cap := pool.Cap() diff --git a/pkg/util/conc/pool_test.go b/pkg/util/conc/pool_test.go index f6fcf4ca50242..3c09fc6b8a308 100644 --- a/pkg/util/conc/pool_test.go +++ b/pkg/util/conc/pool_test.go @@ -21,6 +21,8 @@ import ( "time" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/util/hardware" ) func TestPool(t *testing.T) { @@ -55,6 +57,25 @@ func TestPool(t *testing.T) { } } +func TestPoolResize(t *testing.T) { + cpuNum := hardware.GetCPUNum() + + pool := NewPool[any](cpuNum) + + assert.Equal(t, cpuNum, pool.Cap()) + + err := pool.Resize(cpuNum * 2) + assert.NoError(t, err) + assert.Equal(t, cpuNum*2, pool.Cap()) + + err = pool.Resize(0) + assert.Error(t, err) + + pool = NewDefaultPool[any]() + err = pool.Resize(cpuNum * 2) + assert.Error(t, err) +} + func TestPoolWithPanic(t *testing.T) { pool := NewPool[any](1, WithConcealPanic(true)) diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 155b4dedc1da1..07f2b68509d24 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -2589,7 +2589,7 @@ type dataNodeConfig struct { FlowGraphMaxQueueLength ParamItem `refreshable:"false"` FlowGraphMaxParallelism ParamItem `refreshable:"false"` MaxParallelSyncTaskNum ParamItem `refreshable:"false"` - MaxParallelSyncMgrTasks ParamItem `refreshable:"false"` + MaxParallelSyncMgrTasks ParamItem `refreshable:"true"` // skip mode FlowGraphSkipModeEnable ParamItem `refreshable:"true"` From 2f7252b44effe5aa08bda3fbfe66387e2b0a0c1a Mon Sep 17 00:00:00 2001 From: "cai.zhang" Date: Fri, 15 Dec 2023 12:06:39 +0800 Subject: [PATCH 09/14] enhance: Set default index name as field name (#29218) Signed-off-by: Cai Zhang --- internal/proxy/task_index.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index 28ff1d7fa85d3..e89e28d25fa25 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -19,7 +19,6 @@ package proxy import ( "context" "fmt" - "strconv" "github.com/cockroachdb/errors" "go.uber.org/zap" @@ -399,7 +398,7 @@ func (cit *createIndexTask) Execute(ctx context.Context) error { ) if cit.req.GetIndexName() == "" { - cit.req.IndexName = Params.CommonCfg.DefaultIndexName.GetValue() + "_" + strconv.FormatInt(cit.fieldSchema.GetFieldID(), 10) + cit.req.IndexName = cit.fieldSchema.GetName() } var err error req := &indexpb.CreateIndexRequest{ From 26409d801ee854951eae0efff47cac13317c11b4 Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Fri, 15 Dec 2023 14:00:39 +0800 Subject: [PATCH 10/14] enhance: Remove omp from segcore (#29207) Signed-off-by: Yudong Cai --- internal/core/src/common/RangeSearchHelper.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/core/src/common/RangeSearchHelper.cpp b/internal/core/src/common/RangeSearchHelper.cpp index 9e51dac1e6541..7d099f0ec28e3 100644 --- a/internal/core/src/common/RangeSearchHelper.cpp +++ b/internal/core/src/common/RangeSearchHelper.cpp @@ -82,7 +82,6 @@ ReGenRangeSearchResult(DatasetPtr data_set, } // 
The subscript of p_id and p_dist -#pragma omp parallel for for (int i = 0; i < nq; i++) { std::priority_queue, decltype(cmp)> pq(cmp); From 5164377e6886be4c779f5e1c52f0314d4f281921 Mon Sep 17 00:00:00 2001 From: XuanYang-cn Date: Fri, 15 Dec 2023 16:04:45 +0800 Subject: [PATCH 11/14] fix: Skip updating checkpoint after dropcollection (#29220) Signed-off-by: yangxuan --- internal/datanode/flow_graph_dd_node.go | 2 +- internal/datanode/flow_graph_time_tick_node.go | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/internal/datanode/flow_graph_dd_node.go b/internal/datanode/flow_graph_dd_node.go index b79ea60ceeadc..cf85a36776698 100644 --- a/internal/datanode/flow_graph_dd_node.go +++ b/internal/datanode/flow_graph_dd_node.go @@ -115,7 +115,7 @@ func (ddn *ddNode) Operate(in []Msg) []Msg { } if load := ddn.dropMode.Load(); load != nil && load.(bool) { - log.Info("ddNode in dropMode", + log.RatedInfo(1.0, "ddNode in dropMode", zap.String("vChannelName", ddn.vChannelName), zap.Int64("collectionID", ddn.collectionID)) return []Msg{} diff --git a/internal/datanode/flow_graph_time_tick_node.go b/internal/datanode/flow_graph_time_tick_node.go index f1f65d4e2939a..4b203d1f52daf 100644 --- a/internal/datanode/flow_graph_time_tick_node.go +++ b/internal/datanode/flow_graph_time_tick_node.go @@ -42,6 +42,7 @@ type ttNode struct { writeBufferManager writebuffer.BufferManager lastUpdateTime *atomic.Time cpUpdater *channelCheckpointUpdater + dropMode *atomic.Bool } // Name returns node name, implementing flowgraph.Node @@ -67,6 +68,17 @@ func (ttn *ttNode) Close() { // Operate handles input messages, implementing flowgraph.Node func (ttn *ttNode) Operate(in []Msg) []Msg { fgMsg := in[0].(*flowGraphMsg) + if fgMsg.dropCollection { + ttn.dropMode.Store(true) + } + + // skip updating checkpoint for drop collection + // even if its the close msg + if ttn.dropMode.Load() { + log.RatedInfo(1.0, "ttnode in dropMode", zap.String("channel", ttn.vChannelName)) + return []Msg{} + } + curTs, _ := tsoutil.ParseTS(fgMsg.timeRange.timestampMax) if fgMsg.IsCloseMsg() { if len(fgMsg.endPositions) > 0 { @@ -129,6 +141,7 @@ func newTTNode(config *nodeConfig, wbManager writebuffer.BufferManager, cpUpdate writeBufferManager: wbManager, lastUpdateTime: atomic.NewTime(time.Time{}), // set to Zero to update channel checkpoint immediately after fg started cpUpdater: cpUpdater, + dropMode: atomic.NewBool(false), } return tt, nil From 5da0c8d8e3f213745f8d565f1a9c6ae2b607c96c Mon Sep 17 00:00:00 2001 From: nico <109071306+NicoYuan1986@users.noreply.github.com> Date: Fri, 15 Dec 2023 17:26:48 +0800 Subject: [PATCH 12/14] enhance: [skip e2e]update chart version (#29041) Signed-off-by: nico --- ci/jenkins/Nightly.groovy | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/jenkins/Nightly.groovy b/ci/jenkins/Nightly.groovy index ed327b5d1c153..a461e3aeaf744 100644 --- a/ci/jenkins/Nightly.groovy +++ b/ci/jenkins/Nightly.groovy @@ -8,7 +8,7 @@ String cron_string = BRANCH_NAME == "master" ? 
"50 1 * * * " : "" // Make timeout 4 hours so that we can run two nightly during the ci int total_timeout_minutes = 7 * 60 def imageTag='' -def chart_version='4.0.6' +def chart_version='4.1.10' pipeline { triggers { cron """${cron_timezone} From 1f1a8b7770116cfd4f1e6ddf04314a151fb006fd Mon Sep 17 00:00:00 2001 From: nico <109071306+NicoYuan1986@users.noreply.github.com> Date: Fri, 15 Dec 2023 17:28:42 +0800 Subject: [PATCH 13/14] enhance: modify test cases for error msg update (#29136) Signed-off-by: nico --- tests/python_client/common/common_func.py | 2 +- tests/python_client/testcases/test_search.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python_client/common/common_func.py b/tests/python_client/common/common_func.py index e6bb455d03728..66f8aa3a8aa95 100644 --- a/tests/python_client/common/common_func.py +++ b/tests/python_client/common/common_func.py @@ -944,7 +944,7 @@ def gen_invalid_search_params_type(): scann_search_param = {"index_type": index_type, "search_params": {"nprobe": 8, "reorder_k": reorder_k}} search_params.append(scann_search_param) elif index_type == "DISKANN": - for search_list in ct.get_invalid_ints: + for search_list in ct.get_invalid_ints[1:]: diskann_search_param = {"index_type": index_type, "search_params": {"search_list": search_list}} search_params.append(diskann_search_param) return search_params diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py index 61896d0a078a4..e7057c7731e10 100644 --- a/tests/python_client/testcases/test_search.py +++ b/tests/python_client/testcases/test_search.py @@ -254,7 +254,7 @@ def test_search_param_invalid_dim(self): default_search_params, default_limit, default_search_exp, check_task=CheckTasks.err_res, check_items={"err_code": 65535, - "err_msg": 'failed to search'}) + "err_msg": 'vector dimension mismatch'}) @pytest.mark.tags(CaseLabel.L2) def test_search_param_invalid_field_type(self, get_invalid_fields_type): @@ -837,7 +837,7 @@ def test_search_different_index_invalid_params(self, index, params): search_params[0], default_limit, default_search_exp, check_task=CheckTasks.err_res, - check_items={"err_code": 65535, "err_msg": "failed to search"}) + check_items={"err_code": 65535, "err_msg": "type must be number, but is string"}) @pytest.mark.tags(CaseLabel.L2) def test_search_index_partition_not_existed(self): From 88b4b8b77cb13999eb04cc0f711640dc2b06edc8 Mon Sep 17 00:00:00 2001 From: congqixia Date: Fri, 15 Dec 2023 20:22:39 +0800 Subject: [PATCH 14/14] enhance: make segments SQPool & LoadPool resizable (#29239) See also #29223 --------- Signed-off-by: Congqi Xia --- internal/querynodev2/segments/pool.go | 56 ++++++++++++- internal/querynodev2/segments/pool_test.go | 93 ++++++++++++++++++++++ pkg/util/paramtable/component_param.go | 2 +- 3 files changed, 146 insertions(+), 5 deletions(-) create mode 100644 internal/querynodev2/segments/pool_test.go diff --git a/internal/querynodev2/segments/pool.go b/internal/querynodev2/segments/pool.go index 29c6c65e56bb2..4119cd2e99558 100644 --- a/internal/querynodev2/segments/pool.go +++ b/internal/querynodev2/segments/pool.go @@ -17,12 +17,16 @@ package segments import ( + "context" "math" "runtime" "sync" "go.uber.org/atomic" + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/config" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/hardware" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -45,14 +49,17 @@ var ( 
func initSQPool() { sqOnce.Do(func() { pt := paramtable.Get() + initPoolSize := int(math.Ceil(pt.QueryNodeCfg.MaxReadConcurrency.GetAsFloat() * pt.QueryNodeCfg.CGOPoolSizeRatio.GetAsFloat())) pool := conc.NewPool[any]( - int(math.Ceil(pt.QueryNodeCfg.MaxReadConcurrency.GetAsFloat()*pt.QueryNodeCfg.CGOPoolSizeRatio.GetAsFloat())), - conc.WithPreAlloc(true), + initPoolSize, + conc.WithPreAlloc(false), // pre alloc must be false to resize pool dynamically, use warmup to alloc worker here conc.WithDisablePurge(true), ) conc.WarmupPool(pool, runtime.LockOSThread) - sqp.Store(pool) + + pt.Watch(pt.QueryNodeCfg.MaxReadConcurrency.Key, config.NewHandler("qn.sqpool.maxconc", ResizeSQPool)) + pt.Watch(pt.QueryNodeCfg.CGOPoolSizeRatio.Key, config.NewHandler("qn.sqpool.cgopoolratio", ResizeSQPool)) }) } @@ -71,14 +78,17 @@ func initDynamicPool() { func initLoadPool() { loadOnce.Do(func() { + pt := paramtable.Get() pool := conc.NewPool[any]( - hardware.GetCPUNum()*paramtable.Get().CommonCfg.MiddlePriorityThreadCoreCoefficient.GetAsInt(), + hardware.GetCPUNum()*pt.CommonCfg.MiddlePriorityThreadCoreCoefficient.GetAsInt(), conc.WithPreAlloc(false), conc.WithDisablePurge(false), conc.WithPreHandler(runtime.LockOSThread), // lock os thread for cgo thread disposal ) loadPool.Store(pool) + + pt.Watch(pt.CommonCfg.MiddlePriorityThreadCoreCoefficient.Key, config.NewHandler("qn.loadpool.middlepriority", ResizeLoadPool)) }) } @@ -98,3 +108,41 @@ func GetLoadPool() *conc.Pool[any] { initLoadPool() return loadPool.Load() } + +func ResizeSQPool(evt *config.Event) { + if evt.HasUpdated { + pt := paramtable.Get() + newSize := int(math.Ceil(pt.QueryNodeCfg.MaxReadConcurrency.GetAsFloat() * pt.QueryNodeCfg.CGOPoolSizeRatio.GetAsFloat())) + pool := GetSQPool() + resizePool(pool, newSize, "SQPool") + conc.WarmupPool(pool, runtime.LockOSThread) + } +} + +func ResizeLoadPool(evt *config.Event) { + if evt.HasUpdated { + pt := paramtable.Get() + newSize := hardware.GetCPUNum() * pt.CommonCfg.MiddlePriorityThreadCoreCoefficient.GetAsInt() + resizePool(GetLoadPool(), newSize, "LoadPool") + } +} + +func resizePool(pool *conc.Pool[any], newSize int, tag string) { + log := log.Ctx(context.Background()). + With( + zap.String("poolTag", tag), + zap.Int("newSize", newSize), + ) + + if newSize <= 0 { + log.Warn("cannot set pool size to non-positive value") + return + } + + err := pool.Resize(newSize) + if err != nil { + log.Warn("failed to resize pool", zap.Error(err)) + return + } + log.Info("pool resize successfully") +} diff --git a/internal/querynodev2/segments/pool_test.go b/internal/querynodev2/segments/pool_test.go new file mode 100644 index 0000000000000..6c817bdb1eb9a --- /dev/null +++ b/internal/querynodev2/segments/pool_test.go @@ -0,0 +1,93 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package segments + +import ( + "math" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/config" + "github.com/milvus-io/milvus/pkg/util/conc" + "github.com/milvus-io/milvus/pkg/util/hardware" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +func TestResizePools(t *testing.T) { + paramtable.Get().Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) + pt := paramtable.Get() + + defer func() { + pt.Reset(pt.QueryNodeCfg.MaxReadConcurrency.Key) + pt.Reset(pt.QueryNodeCfg.CGOPoolSizeRatio.Key) + pt.Reset(pt.CommonCfg.MiddlePriorityThreadCoreCoefficient.Key) + }() + + t.Run("SQPool", func(t *testing.T) { + expectedCap := int(math.Ceil(pt.QueryNodeCfg.MaxReadConcurrency.GetAsFloat() * pt.QueryNodeCfg.CGOPoolSizeRatio.GetAsFloat())) + + ResizeSQPool(&config.Event{ + HasUpdated: true, + }) + assert.Equal(t, expectedCap, GetSQPool().Cap()) + + pt.Save(pt.QueryNodeCfg.CGOPoolSizeRatio.Key, strconv.FormatFloat(pt.QueryNodeCfg.CGOPoolSizeRatio.GetAsFloat()*2, 'f', 10, 64)) + expectedCap = int(math.Ceil(pt.QueryNodeCfg.MaxReadConcurrency.GetAsFloat() * pt.QueryNodeCfg.CGOPoolSizeRatio.GetAsFloat())) + ResizeSQPool(&config.Event{ + HasUpdated: true, + }) + assert.Equal(t, expectedCap, GetSQPool().Cap()) + + pt.Save(pt.QueryNodeCfg.CGOPoolSizeRatio.Key, "0") + ResizeSQPool(&config.Event{ + HasUpdated: true, + }) + assert.Equal(t, expectedCap, GetSQPool().Cap(), "pool shall not be resized when newSize is 0") + }) + + t.Run("LoadPool", func(t *testing.T) { + expectedCap := hardware.GetCPUNum() * pt.CommonCfg.MiddlePriorityThreadCoreCoefficient.GetAsInt() + + ResizeLoadPool(&config.Event{ + HasUpdated: true, + }) + assert.Equal(t, expectedCap, GetLoadPool().Cap()) + + pt.Save(pt.CommonCfg.MiddlePriorityThreadCoreCoefficient.Key, strconv.FormatFloat(pt.CommonCfg.MiddlePriorityThreadCoreCoefficient.GetAsFloat()*2, 'f', 10, 64)) + ResizeLoadPool(&config.Event{ + HasUpdated: true, + }) + assert.Equal(t, expectedCap, GetLoadPool().Cap()) + + pt.Save(pt.CommonCfg.MiddlePriorityThreadCoreCoefficient.Key, "0") + ResizeLoadPool(&config.Event{ + HasUpdated: true, + }) + assert.Equal(t, expectedCap, GetLoadPool().Cap()) + }) + + t.Run("error_pool", func(*testing.T) { + pool := conc.NewDefaultPool[any]() + c := pool.Cap() + + resizePool(pool, c*2, "debug") + + assert.Equal(t, c, pool.Cap()) + }) +} diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 07f2b68509d24..dc02775ce1f1e 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -1703,7 +1703,7 @@ type queryNodeConfig struct { SchedulePolicyMaxPendingTaskPerUser ParamItem `refreshable:"true"` // CGOPoolSize ratio to MaxReadConcurrency - CGOPoolSizeRatio ParamItem `refreshable:"false"` + CGOPoolSizeRatio ParamItem `refreshable:"true"` EnableWorkerSQCostMetrics ParamItem `refreshable:"true"` }
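
The syncmgr patch and this final patch follow the same pattern: a paramtable item is marked refreshable, a handler is registered with config.NewHandler via Watch, and that handler tunes an existing conc.Pool through the new Resize method, ignoring unparsable or non-positive values. Below is a minimal illustrative sketch of that wiring, not part of any patch in this series; the config key "example.pool.size" and the pool sizes are invented for the example, and the handler is invoked directly instead of going through paramtable's watch machinery.

package main

import (
	"fmt"
	"strconv"

	"github.com/milvus-io/milvus/pkg/config"
	"github.com/milvus-io/milvus/pkg/util/conc"
)

func main() {
	// Start with a small pool; pre-alloc is left off so Resize is permitted
	// (Resize returns an error for pre-alloc pools, as TestPoolResize shows).
	pool := conc.NewPool[any](4)

	// Handler shaped like resizeHandler / ResizeSQPool in the patches above:
	// act only on real updates, skip values that do not parse or are not positive.
	resize := func(evt *config.Event) {
		if !evt.HasUpdated {
			return
		}
		size, err := strconv.ParseInt(evt.Value, 10, 64)
		if err != nil {
			return // non-numeric update, keep the current pool size
		}
		if err := pool.Resize(int(size)); err != nil {
			return // Resize rejects non-positive sizes and pre-alloc pools
		}
	}

	// In the patches this event is delivered by
	// paramtable.Get().Watch(key, config.NewHandler(tag, handler));
	// here the handler is called directly with a hypothetical key.
	resize(&config.Event{Key: "example.pool.size", Value: "8", HasUpdated: true})

	fmt.Println(pool.Cap()) // expected to print 8
}

The guard against non-positive and unparsable values is the same design choice seen in resizePool and resizeHandler above: a bad config update leaves the worker pool at its previous capacity instead of shrinking it to zero.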