diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go
index e0bb9794d2eff..0201bfec2b480 100644
--- a/internal/proxy/lb_policy.go
+++ b/internal/proxy/lb_policy.go
@@ -98,6 +98,23 @@ func (lb *LBPolicyImpl) Start(ctx context.Context) {
 	}
 }
 
+// GetShardLeaders should always retry until ctx is done, except when the collection is not loaded.
+func (lb *LBPolicyImpl) GetShardLeaders(ctx context.Context, dbName string, collName string, collectionID int64, withCache bool) (map[string][]nodeInfo, error) {
+	var shardLeaders map[string][]nodeInfo
+	// use retry to handle the case where the query coord service is not ready yet
+	err := retry.Handle(ctx, func() (bool, error) {
+		var err error
+		shardLeaders, err = globalMetaCache.GetShards(ctx, withCache, dbName, collName, collectionID)
+		if err != nil {
+			return !errors.Is(err, merr.ErrCollectionNotLoaded), err
+		}
+
+		return false, nil
+	})
+
+	return shardLeaders, err
+}
+
 // try to select the best node from the available nodes
 func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) {
 	availableNodes := lo.FilterMap(workload.shardLeaders, func(node int64, _ int) (int64, bool) { return node, !excludeNodes.Contain(node) })
@@ -105,7 +122,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
 	if err != nil {
 		log := log.Ctx(ctx)
 		globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName)
-		shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collectionName, workload.collectionID)
+		shardLeaders, err := lb.GetShardLeaders(ctx, workload.db, workload.collectionName, workload.collectionID, false)
 		if err != nil {
 			log.Warn("failed to get shard delegator",
 				zap.Int64("collectionID", workload.collectionID),
@@ -195,7 +212,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
 
 // Execute will execute collection workload in parallel
 func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad) error {
-	dml2leaders, err := globalMetaCache.GetShards(ctx, true, workload.db, workload.collectionName, workload.collectionID)
+	dml2leaders, err := lb.GetShardLeaders(ctx, workload.db, workload.collectionName, workload.collectionID, true)
 	if err != nil {
 		log.Ctx(ctx).Warn("failed to get shards", zap.Error(err))
 		return err
diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go
index 1d360f758c734..34c6db0817723 100644
--- a/internal/proxy/meta_cache.go
+++ b/internal/proxy/meta_cache.go
@@ -1004,9 +1004,7 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
 	if _, ok := m.collLeader[database]; !ok {
 		m.collLeader[database] = make(map[string]*shardLeaders)
 	}
-	m.collLeader[database][collectionName] = newShardLeaders
 
-	m.leaderMut.Unlock()
 
 	iterator := newShardLeaders.GetReader()
 	ret := iterator.Shuffle()
@@ -1016,8 +1014,10 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
 		oldLeaders = cacheShardLeaders.shardLeaders
 	}
 	// update refcnt in shardClientMgr
-	// and create new client for new leaders
+	// updating shard leaders just creates an empty client pool;
+	// new clients are initialized lazily on first use in getClient
 	_ = m.shardMgr.UpdateShardLeaders(oldLeaders, ret)
+	m.leaderMut.Unlock()
 
 	metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
 	return ret, nil
diff --git a/internal/proxy/shard_client.go b/internal/proxy/shard_client.go
index 237b3b3e8a0a3..8475494e6659c 100644
--- a/internal/proxy/shard_client.go
+++ b/internal/proxy/shard_client.go
@@ -31,26 +31,38 @@ var errClosed = errors.New("client is closed")
 type shardClient struct {
 	sync.RWMutex
 	info     nodeInfo
-	client   types.QueryNodeClient
 	isClosed bool
 	refCnt   int
 	clients  []types.QueryNodeClient
 	idx      atomic.Int64
 	poolSize int
 	pooling  bool
+
+	initialized atomic.Bool
+	creator     queryNodeCreatorFunc
 }
 
 func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, error) {
+	if !n.initialized.Load() {
+		n.Lock()
+		if !n.initialized.Load() {
+			if err := n.initClients(); err != nil {
+				n.Unlock()
+				return nil, err
+			}
+			n.initialized.Store(true)
+		}
+		n.Unlock()
+	}
+
 	n.RLock()
 	defer n.RUnlock()
 	if n.isClosed {
 		return nil, errClosed
 	}
-	if n.pooling {
-		idx := n.idx.Inc()
-		return n.clients[int(idx)%n.poolSize], nil
-	}
-	return n.client, nil
+
+	idx := n.idx.Inc()
+	return n.clients[int(idx)%n.poolSize], nil
 }
 
 func (n *shardClient) inc() {
@@ -65,12 +77,13 @@ func (n *shardClient) close() {
 	n.isClosed = true
 	n.refCnt = 0
-	if n.client != nil {
-		if err := n.client.Close(); err != nil {
+
+	for _, client := range n.clients {
+		if err := client.Close(); err != nil {
 			log.Warn("close grpc client failed", zap.Error(err))
 		}
 
-		n.client = nil
 	}
+	n.clients = nil
 }
 
 func (n *shardClient) dec() bool {
@@ -94,41 +107,39 @@ func (n *shardClient) Close() {
 	n.close()
 }
 
-func newShardClient(info *nodeInfo, client types.QueryNodeClient) *shardClient {
-	ret := &shardClient{
+func newPoolingShardClient(info *nodeInfo, creator queryNodeCreatorFunc) (*shardClient, error) {
+	return &shardClient{
 		info: nodeInfo{
 			nodeID:  info.nodeID,
 			address: info.address,
 		},
-		client: client,
-		refCnt: 1,
-	}
-	return ret
+		refCnt:  1,
+		pooling: true,
+		creator: creator,
+	}, nil
 }
 
-func newPoolingShardClient(info *nodeInfo, creator queryNodeCreatorFunc) (*shardClient, error) {
+func (n *shardClient) initClients() error {
 	num := paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt()
 	if num <= 0 {
 		num = 1
 	}
 	clients := make([]types.QueryNodeClient, 0, num)
 	for i := 0; i < num; i++ {
-		client, err := creator(context.Background(), info.address, info.nodeID)
+		client, err := n.creator(context.Background(), n.info.address, n.info.nodeID)
 		if err != nil {
-			return nil, err
+			// roll back the clients that were already created
+			for _, c := range clients[:i] {
+				c.Close()
+			}
+			return errors.Wrap(err, fmt.Sprintf("create client for node=%d failed", n.info.nodeID))
		}
 		clients = append(clients, client)
 	}
-	return &shardClient{
-		info: nodeInfo{
-			nodeID:  info.nodeID,
-			address: info.address,
-		},
-		refCnt:   1,
-		pooling:  true,
-		clients:  clients,
-		poolSize: num,
-	}, nil
+
+	n.clients = clients
+	n.poolSize = num
+	return nil
 }
 
 type shardClientMgr interface {
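Note on the new GetShardLeaders wrapper: it leans on retry.Handle's contract as used in the diff — the callback's boolean says whether the error is worth retrying, so `!errors.Is(err, merr.ErrCollectionNotLoaded)` keeps retrying transient failures (query coord not ready) until ctx is done, but gives up immediately when the collection simply is not loaded. Below is a minimal, self-contained sketch of that contract; `handle` and `errNotLoaded` are hypothetical stand-ins for the real retry.Handle (whose backoff options are not shown in this diff) and merr.ErrCollectionNotLoaded.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// errNotLoaded plays the role of merr.ErrCollectionNotLoaded:
// an error that makes further retries pointless.
var errNotLoaded = errors.New("collection not loaded")

// handle is a stand-in for retry.Handle: keep calling fn until ctx is done,
// fn succeeds, or fn reports its error as non-retryable.
func handle(ctx context.Context, fn func() (retryable bool, err error)) error {
	for {
		retryable, err := fn()
		if err == nil || !retryable {
			return err
		}
		select {
		case <-ctx.Done():
			return errors.Join(ctx.Err(), err)
		case <-time.After(100 * time.Millisecond):
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	attempts := 0
	err := handle(ctx, func() (bool, error) {
		attempts++
		if attempts < 3 {
			// transient: keep retrying, like "query coord not ready"
			return true, errors.New("query coord not ready")
		}
		// non-retryable: mirrors !errors.Is(err, merr.ErrCollectionNotLoaded)
		return false, errNotLoaded
	})
	fmt.Println(attempts, err) // 3 collection not loaded
}
```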
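The getClient change is classic double-checked locking: check the `initialized` flag without the lock, then re-check under the lock, so exactly one caller pays the dial cost and every later caller takes the cheap read path. A minimal sketch of the same shape, using only the standard library (`lazyPool` and its fields are illustrative, not from the Milvus codebase; stdlib atomic.Bool/Int64 stand in for the go.uber.org/atomic types used above):

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// lazyPool mirrors shardClient's shape: an atomic flag checked before the
// lock, then re-checked under the lock, so only the first caller builds the
// pool while concurrent callers wait instead of building it twice.
type lazyPool struct {
	mu          sync.RWMutex
	initialized atomic.Bool
	conns       []string // stand-in for []types.QueryNodeClient
	idx         atomic.Int64
}

func (p *lazyPool) get() (string, error) {
	if !p.initialized.Load() {
		p.mu.Lock()
		if !p.initialized.Load() {
			if err := p.init(); err != nil {
				p.mu.Unlock()
				return "", err
			}
			p.initialized.Store(true)
		}
		p.mu.Unlock()
	}

	p.mu.RLock()
	defer p.mu.RUnlock()
	// round-robin over the pool, like idx.Inc() % poolSize above
	i := p.idx.Add(1)
	return p.conns[int(i)%len(p.conns)], nil
}

func (p *lazyPool) init() error {
	for i := 0; i < 4; i++ {
		p.conns = append(p.conns, fmt.Sprintf("conn-%d", i))
	}
	return nil
}

func main() {
	p := &lazyPool{}
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			c, _ := p.get()
			fmt.Println(c)
		}()
	}
	wg.Wait()
}
```

This is also why the meta_cache.go hunk can let UpdateShardLeaders build only empty pools: no connection is dialed until the first getClient on that node.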
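Finally, the rollback added in initClients matters because initialization is now lazy and can fail on a live request path: if the i-th dial fails, the i clients already created must be closed, or a half-built pool leaks connections every time a node is briefly unreachable. A small stand-alone illustration of that rollback shape (`conn`, `newConn`, and `initAll` are hypothetical stand-ins for a QueryNode client, its creator func, and initClients):

```go
package main

import "fmt"

// conn and newConn are hypothetical stand-ins for a QueryNode grpc client
// and the queryNodeCreatorFunc that dials it.
type conn struct{ id int }

func (c *conn) Close() error { return nil }

func newConn(id int) (*conn, error) {
	if id == 2 {
		return nil, fmt.Errorf("dial failed for conn %d", id)
	}
	return &conn{id: id}, nil
}

// initAll mirrors initClients' rollback: if creating connection i fails,
// close the i connections already created so nothing leaks.
func initAll(n int) ([]*conn, error) {
	conns := make([]*conn, 0, n)
	for i := 0; i < n; i++ {
		c, err := newConn(i)
		if err != nil {
			for _, created := range conns {
				created.Close() // roll back what was built so far
			}
			return nil, fmt.Errorf("init conn %d: %w", i, err)
		}
		conns = append(conns, c)
	}
	return conns, nil
}

func main() {
	if _, err := initAll(4); err != nil {
		fmt.Println(err) // init conn 2: dial failed for conn 2
	}
}
```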