From c64a078458d8fea03a4da0ce5a59142ba4223529 Mon Sep 17 00:00:00 2001 From: congqixia Date: Fri, 2 Aug 2024 11:24:19 +0800 Subject: [PATCH] enhance: Support proxy/delegator qn client pooling (#35194) See also #35196 Add param item for proxy/delegator query node client pooling and implement pooling logic --------- Signed-off-by: Congqi Xia --- configs/milvus.yaml | 4 ++ internal/proxy/shard_client.go | 38 +++++++++++++- internal/querynodev2/cluster/worker.go | 69 ++++++++++++++++++++++---- internal/querynodev2/server.go | 9 ++-- pkg/util/paramtable/component_param.go | 22 ++++++++ 5 files changed, 123 insertions(+), 19 deletions(-) diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 5dc3fb6d7836a..af51ae932234a 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -299,6 +299,8 @@ proxy: maxConnectionNum: 10000 # the max client info numbers that proxy should manage, avoid too many client infos gracefulStopTimeout: 30 # seconds. force stop node without graceful stop slowQuerySpanInSeconds: 5 # query whose executed time exceeds the `slowQuerySpanInSeconds` can be considered slow, in seconds. + queryNodePooling: + size: 10 # the size for shardleader(querynode) client pool http: enabled: true # Whether to enable the http server debug_mode: false # Whether to enable http server debug mode @@ -451,6 +453,8 @@ queryNode: enableSegmentPrune: false # use partition stats to prune data in search/query on shard delegator queryStreamBatchSize: 4194304 # return batch size of stream query bloomFilterApplyParallelFactor: 4 # parallel factor when to apply pk to bloom filter, default to 4*CPU_CORE_NUM + workerPooling: + size: 10 # the size for worker querynode client pool ip: # TCP/IP address of queryNode. If not specified, use the first unicastable address port: 21123 # TCP port of queryNode grpc: diff --git a/internal/proxy/shard_client.go b/internal/proxy/shard_client.go index c250de1d6aab4..237b3b3e8a0a3 100644 --- a/internal/proxy/shard_client.go +++ b/internal/proxy/shard_client.go @@ -6,11 +6,13 @@ import ( "sync" "github.com/cockroachdb/errors" + "go.uber.org/atomic" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/registry" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) type queryNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) @@ -32,6 +34,10 @@ type shardClient struct { client types.QueryNodeClient isClosed bool refCnt int + clients []types.QueryNodeClient + idx atomic.Int64 + poolSize int + pooling bool } func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, error) { @@ -40,6 +46,10 @@ func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, err if n.isClosed { return nil, errClosed } + if n.pooling { + idx := n.idx.Inc() + return n.clients[int(idx)%n.poolSize], nil + } return n.client, nil } @@ -96,6 +106,31 @@ func newShardClient(info *nodeInfo, client types.QueryNodeClient) *shardClient { return ret } +func newPoolingShardClient(info *nodeInfo, creator queryNodeCreatorFunc) (*shardClient, error) { + num := paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt() + if num <= 0 { + num = 1 + } + clients := make([]types.QueryNodeClient, 0, num) + for i := 0; i < num; i++ { + client, err := creator(context.Background(), info.address, info.nodeID) + if err != nil { + return nil, err + } + clients = append(clients, client) + } + return &shardClient{ + info: nodeInfo{ + nodeID: info.nodeID, + 
address: info.address, + }, + refCnt: 1, + pooling: true, + clients: clients, + poolSize: num, + }, nil +} + type shardClientMgr interface { GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNodeClient, error) UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error @@ -182,11 +217,10 @@ func (c *shardClientMgrImpl) UpdateShardLeaders(oldLeaders map[string][]nodeInfo if c.clientCreator == nil { return fmt.Errorf("clientCreator function is nil") } - shardClient, err := c.clientCreator(context.Background(), node.address, node.nodeID) + client, err := newPoolingShardClient(node, c.clientCreator) if err != nil { return err } - client := newShardClient(node, shardClient) c.clients.data[node.nodeID] = client } } diff --git a/internal/querynodev2/cluster/worker.go b/internal/querynodev2/cluster/worker.go index 349f7f79d5689..af524cb899cc1 100644 --- a/internal/querynodev2/cluster/worker.go +++ b/internal/querynodev2/cluster/worker.go @@ -22,6 +22,7 @@ import ( "io" "github.com/cockroachdb/errors" + "go.uber.org/atomic" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/proto/internalpb" @@ -30,6 +31,7 @@ import ( "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) // Worker is the interface definition for querynode worker role. @@ -48,22 +50,56 @@ type Worker interface { // remoteWorker wraps grpc QueryNode client as Worker. type remoteWorker struct { - client types.QueryNodeClient + client types.QueryNodeClient + clients []types.QueryNodeClient + poolSize int + idx atomic.Int64 + pooling bool } // NewRemoteWorker creates a grpcWorker. func NewRemoteWorker(client types.QueryNodeClient) Worker { return &remoteWorker{ - client: client, + client: client, + pooling: false, } } +func NewPoolingRemoteWorker(fn func() (types.QueryNodeClient, error)) (Worker, error) { + num := paramtable.Get().QueryNodeCfg.WorkerPoolingSize.GetAsInt() + if num <= 0 { + num = 1 + } + clients := make([]types.QueryNodeClient, 0, num) + for i := 0; i < num; i++ { + c, err := fn() + if err != nil { + return nil, err + } + clients = append(clients, c) + } + return &remoteWorker{ + pooling: true, + clients: clients, + poolSize: num, + }, nil +} + +func (w *remoteWorker) getClient() types.QueryNodeClient { + if w.pooling { + idx := w.idx.Inc() + return w.clients[int(idx)%w.poolSize] + } + return w.client +} + // LoadSegments implements Worker. 
func (w *remoteWorker) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error { log := log.Ctx(ctx).With( zap.Int64("workerID", req.GetDstNodeID()), ) - status, err := w.client.LoadSegments(ctx, req) + client := w.getClient() + status, err := client.LoadSegments(ctx, req) if err = merr.CheckRPCCall(status, err); err != nil { log.Warn("failed to call LoadSegments via grpc worker", zap.Error(err), @@ -77,7 +113,8 @@ func (w *remoteWorker) ReleaseSegments(ctx context.Context, req *querypb.Release log := log.Ctx(ctx).With( zap.Int64("workerID", req.GetNodeID()), ) - status, err := w.client.ReleaseSegments(ctx, req) + client := w.getClient() + status, err := client.ReleaseSegments(ctx, req) if err = merr.CheckRPCCall(status, err); err != nil { log.Warn("failed to call ReleaseSegments via grpc worker", zap.Error(err), @@ -91,7 +128,8 @@ func (w *remoteWorker) Delete(ctx context.Context, req *querypb.DeleteRequest) e log := log.Ctx(ctx).With( zap.Int64("workerID", req.GetBase().GetTargetID()), ) - status, err := w.client.Delete(ctx, req) + client := w.getClient() + status, err := client.Delete(ctx, req) if err := merr.CheckRPCCall(status, err); err != nil { if errors.Is(err, merr.ErrServiceUnimplemented) { log.Warn("invoke legacy querynode Delete method, ignore error", zap.Error(err)) @@ -104,27 +142,30 @@ func (w *remoteWorker) Delete(ctx context.Context, req *querypb.DeleteRequest) e } func (w *remoteWorker) SearchSegments(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) { - ret, err := w.client.SearchSegments(ctx, req) + client := w.getClient() + ret, err := client.SearchSegments(ctx, req) if err != nil && errors.Is(err, merr.ErrServiceUnimplemented) { // for compatible with rolling upgrade from version before v2.2.9 - return w.client.Search(ctx, req) + return client.Search(ctx, req) } return ret, err } func (w *remoteWorker) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { - ret, err := w.client.QuerySegments(ctx, req) + client := w.getClient() + ret, err := client.QuerySegments(ctx, req) if err != nil && errors.Is(err, merr.ErrServiceUnimplemented) { // for compatible with rolling upgrade from version before v2.2.9 - return w.client.Query(ctx, req) + return client.Query(ctx, req) } return ret, err } func (w *remoteWorker) QueryStreamSegments(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error { - client, err := w.client.QueryStreamSegments(ctx, req) + c := w.getClient() + client, err := c.QueryStreamSegments(ctx, req) if err != nil { return err } @@ -155,7 +196,8 @@ func (w *remoteWorker) QueryStreamSegments(ctx context.Context, req *querypb.Que } func (w *remoteWorker) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) { - return w.client.GetStatistics(ctx, req) + client := w.getClient() + return client.GetStatistics(ctx, req) } func (w *remoteWorker) IsHealthy() bool { @@ -163,6 +205,11 @@ func (w *remoteWorker) IsHealthy() bool { } func (w *remoteWorker) Stop() { + if w.pooling { + for _, client := range w.clients { + client.Close() + } + } if err := w.client.Close(); err != nil { log.Warn("failed to call Close via grpc worker", zap.Error(err)) } diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index 508eb6e484da4..cac33a6a99c38 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -337,12 +337,9 @@ func (node *QueryNode) Init() 
error { } } - client, err := grpcquerynodeclient.NewClient(node.ctx, addr, nodeID) - if err != nil { - return nil, err - } - - return cluster.NewRemoteWorker(client), nil + return cluster.NewPoolingRemoteWorker(func() (types.QueryNodeClient, error) { + return grpcquerynodeclient.NewClient(node.ctx, addr, nodeID) + }) }) node.delegators = typeutil.NewConcurrentMap[string, delegator.ShardDelegator]() node.subscribingChannels = typeutil.NewConcurrentSet[string]() diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 41265dae09d79..36e0c6027db51 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -1209,6 +1209,7 @@ type proxyConfig struct { GracefulStopTimeout ParamItem `refreshable:"true"` SlowQuerySpanInSeconds ParamItem `refreshable:"true"` + QueryNodePoolingSize ParamItem `refreshable:"false"` } func (p *proxyConfig) init(base *BaseTable) { @@ -1611,6 +1612,15 @@ please adjust in embedded Milvus: false`, Export: true, } p.SlowQuerySpanInSeconds.Init(base.mgr) + + p.QueryNodePoolingSize = ParamItem{ + Key: "proxy.queryNodePooling.size", + Version: "2.4.7", + Doc: "the size for shardleader(querynode) client pool", + DefaultValue: "10", + Export: true, + } + p.QueryNodePoolingSize.Init(base.mgr) } // ///////////////////////////////////////////////////////////////////////////// @@ -2314,6 +2324,9 @@ type queryNodeConfig struct { UseStreamComputing ParamItem `refreshable:"false"` QueryStreamBatchSize ParamItem `refreshable:"false"` BloomFilterApplyParallelFactor ParamItem `refreshable:"true"` + + // worker + WorkerPoolingSize ParamItem `refreshable:"false"` } func (p *queryNodeConfig) init(base *BaseTable) { @@ -2955,6 +2968,15 @@ user-task-polling: Export: true, } p.BloomFilterApplyParallelFactor.Init(base.mgr) + + p.WorkerPoolingSize = ParamItem{ + Key: "queryNode.workerPooling.size", + Version: "2.4.7", + Doc: "the size for worker querynode client pool", + DefaultValue: "10", + Export: true, + } + p.WorkerPoolingSize.Init(base.mgr) } // /////////////////////////////////////////////////////////////////////////////
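
Both halves of this patch apply the same technique: instead of holding a single QueryNodeClient per remote node, the proxy's shardClient and the delegator's remoteWorker keep a small slice of clients plus an atomic counter, and each call picks clients[counter % poolSize]. The sketch below is a minimal, self-contained illustration of that round-robin pooling pattern, not the Milvus code itself: the queryClient interface, fakeClient type, and newClientPool constructor are invented for the example, and it uses the standard library's sync/atomic.Int64 in place of the go.uber.org/atomic package the patch imports.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// queryClient stands in for the real QueryNode gRPC client interface;
// only the parts needed to show the pooling pattern are modeled here.
type queryClient interface {
	Search(req string) string
	Close() error
}

// fakeClient is a stand-in for a grpc-backed client; it just records its id.
type fakeClient struct{ id int }

func (c *fakeClient) Search(req string) string {
	return fmt.Sprintf("client-%d handled %q", c.id, req)
}
func (c *fakeClient) Close() error { return nil }

// clientPool keeps a fixed set of clients and hands them out round-robin,
// mirroring the clients/idx/poolSize fields added to shardClient and remoteWorker.
type clientPool struct {
	clients  []queryClient
	idx      atomic.Int64
	poolSize int
}

// newClientPool constructs `size` clients up front via the creator callback.
// Like newPoolingShardClient and NewPoolingRemoteWorker, a non-positive size
// falls back to a pool of one.
func newClientPool(size int, creator func(i int) (queryClient, error)) (*clientPool, error) {
	if size <= 0 {
		size = 1
	}
	clients := make([]queryClient, 0, size)
	for i := 0; i < size; i++ {
		c, err := creator(i)
		if err != nil {
			return nil, err
		}
		clients = append(clients, c)
	}
	return &clientPool{clients: clients, poolSize: size}, nil
}

// get picks the next client: an atomic increment plus modulo spreads
// concurrent callers across the pool without any locking.
func (p *clientPool) get() queryClient {
	idx := p.idx.Add(1)
	return p.clients[int(idx)%p.poolSize]
}

// close releases every pooled client.
func (p *clientPool) close() {
	for _, c := range p.clients {
		_ = c.Close()
	}
}

func main() {
	pool, err := newClientPool(3, func(i int) (queryClient, error) {
		return &fakeClient{id: i}, nil
	})
	if err != nil {
		panic(err)
	}
	defer pool.close()

	// Concurrent requests rotate over the three pooled clients.
	var wg sync.WaitGroup
	for i := 0; i < 6; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			fmt.Println(pool.get().Search(fmt.Sprintf("req-%d", i)))
		}(i)
	}
	wg.Wait()
}

In the patch itself the pool sizes come from the new proxy.queryNodePooling.size and queryNode.workerPooling.size items, both defaulting to 10, and a non-positive value falls back to a pool of one. The usual motivation for this kind of pooling is that a single gRPC client connection multiplexes all of its concurrent streams over one underlying connection, which can become a bottleneck when a proxy or delegator drives many parallel search/query RPCs against the same QueryNode; rotating over several connections spreads that load.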