From d29e01e2847928bf22d7681bd3faed9057e6aad3 Mon Sep 17 00:00:00 2001 From: Zhen Ye Date: Thu, 26 Sep 2024 21:23:16 +0800 Subject: [PATCH] fix: port listen racing in mix or standalone mode (#36442) issue: #36441 --------- Signed-off-by: chyezh --- cmd/components/data_coord.go | 10 +- cmd/components/data_node.go | 4 + cmd/components/index_coord.go | 4 + cmd/components/index_node.go | 4 + cmd/components/proxy.go | 4 + cmd/components/query_coord.go | 4 + cmd/components/query_node.go | 4 + cmd/components/root_coord.go | 4 + cmd/roles/roles.go | 4 + internal/distributed/datacoord/service.go | 51 +++-- .../distributed/datacoord/service_test.go | 38 ++- internal/distributed/datanode/service.go | 64 +++--- internal/distributed/datanode/service_test.go | 6 + internal/distributed/indexnode/service.go | 50 ++-- .../distributed/indexnode/service_test.go | 2 + .../distributed/proxy/listener_manager.go | 216 ++++++++++++++++++ internal/distributed/proxy/service.go | 179 +++------------ internal/distributed/proxy/service_test.go | 16 +- internal/distributed/querycoord/service.go | 60 +++-- .../distributed/querycoord/service_test.go | 2 + internal/distributed/querynode/service.go | 66 +++--- .../distributed/querynode/service_test.go | 4 + internal/distributed/rootcoord/service.go | 66 +++--- .../distributed/rootcoord/service_test.go | 30 ++- internal/distributed/streamingnode/service.go | 72 +++--- internal/proxy/proxy_test.go | 38 ++- pkg/util/netutil/listener.go | 147 ++++++++++++ pkg/util/netutil/listener_test.go | 38 +++ .../search_after_coord_down_test.go | 25 +- .../coordrecovery/coord_recovery_test.go | 26 +-- tests/integration/datanode/compaction_test.go | 7 +- tests/integration/minicluster_v2.go | 133 ++++++----- tests/integration/target/target_test.go | 13 +- 33 files changed, 906 insertions(+), 485 deletions(-) create mode 100644 internal/distributed/proxy/listener_manager.go create mode 100644 pkg/util/netutil/listener.go create mode 100644 pkg/util/netutil/listener_test.go diff --git a/cmd/components/data_coord.go b/cmd/components/data_coord.go index 977a52a42dece..6a031174d829d 100644 --- a/cmd/components/data_coord.go +++ b/cmd/components/data_coord.go @@ -39,7 +39,10 @@ type DataCoord struct { // NewDataCoord creates a new DataCoord func NewDataCoord(ctx context.Context, factory dependency.Factory) (*DataCoord, error) { - s := grpcdatacoordclient.NewServer(ctx, factory) + s, err := grpcdatacoordclient.NewServer(ctx, factory) + if err != nil { + return nil, err + } return &DataCoord{ ctx: ctx, @@ -47,6 +50,11 @@ func NewDataCoord(ctx context.Context, factory dependency.Factory) (*DataCoord, }, nil } +// Prepare prepares service +func (s *DataCoord) Prepare() error { + return s.svr.Prepare() +} + // Run starts service func (s *DataCoord) Run() error { if err := s.svr.Run(); err != nil { diff --git a/cmd/components/data_node.go b/cmd/components/data_node.go index 8fbba83a0800d..f2830bcbf9549 100644 --- a/cmd/components/data_node.go +++ b/cmd/components/data_node.go @@ -50,6 +50,10 @@ func NewDataNode(ctx context.Context, factory dependency.Factory) (*DataNode, er }, nil } +func (d *DataNode) Prepare() error { + return d.svr.Prepare() +} + // Run starts service func (d *DataNode) Run() error { if err := d.svr.Run(); err != nil { diff --git a/cmd/components/index_coord.go b/cmd/components/index_coord.go index ff03f83789318..2d03c1d08fadf 100644 --- a/cmd/components/index_coord.go +++ b/cmd/components/index_coord.go @@ -33,6 +33,10 @@ func NewIndexCoord(ctx context.Context, factory dependency.Factory) (*IndexCoord return &IndexCoord{}, nil } +func (s *IndexCoord) Prepare() error { + return nil +} + // Run starts service func (s *IndexCoord) Run() error { log.Info("IndexCoord running ...") diff --git a/cmd/components/index_node.go b/cmd/components/index_node.go index edf72384d4d2d..0d4f6496f76a7 100644 --- a/cmd/components/index_node.go +++ b/cmd/components/index_node.go @@ -48,6 +48,10 @@ func NewIndexNode(ctx context.Context, factory dependency.Factory) (*IndexNode, return n, nil } +func (n *IndexNode) Prepare() error { + return n.svr.Prepare() +} + // Run starts service func (n *IndexNode) Run() error { if err := n.svr.Run(); err != nil { diff --git a/cmd/components/proxy.go b/cmd/components/proxy.go index cb74b36680a90..5fcda443f85b2 100644 --- a/cmd/components/proxy.go +++ b/cmd/components/proxy.go @@ -49,6 +49,10 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) { return n, nil } +func (n *Proxy) Prepare() error { + return n.svr.Prepare() +} + // Run starts service func (n *Proxy) Run() error { if err := n.svr.Run(); err != nil { diff --git a/cmd/components/query_coord.go b/cmd/components/query_coord.go index c98812d86ef62..f796e7c9e3523 100644 --- a/cmd/components/query_coord.go +++ b/cmd/components/query_coord.go @@ -50,6 +50,10 @@ func NewQueryCoord(ctx context.Context, factory dependency.Factory) (*QueryCoord }, nil } +func (qs *QueryCoord) Prepare() error { + return qs.svr.Prepare() +} + // Run starts service func (qs *QueryCoord) Run() error { if err := qs.svr.Run(); err != nil { diff --git a/cmd/components/query_node.go b/cmd/components/query_node.go index 3857f81bafa49..325fd77c72a71 100644 --- a/cmd/components/query_node.go +++ b/cmd/components/query_node.go @@ -50,6 +50,10 @@ func NewQueryNode(ctx context.Context, factory dependency.Factory) (*QueryNode, }, nil } +func (q *QueryNode) Prepare() error { + return q.svr.Prepare() +} + // Run starts service func (q *QueryNode) Run() error { if err := q.svr.Run(); err != nil { diff --git a/cmd/components/root_coord.go b/cmd/components/root_coord.go index e130516ac8d16..24040bc94a8ca 100644 --- a/cmd/components/root_coord.go +++ b/cmd/components/root_coord.go @@ -49,6 +49,10 @@ func NewRootCoord(ctx context.Context, factory dependency.Factory) (*RootCoord, }, nil } +func (rc *RootCoord) Prepare() error { + return rc.svr.Prepare() +} + // Run starts service func (rc *RootCoord) Run() error { if err := rc.svr.Run(); err != nil { diff --git a/cmd/roles/roles.go b/cmd/roles/roles.go index f5073df7c07d4..b1db0682e1d60 100644 --- a/cmd/roles/roles.go +++ b/cmd/roles/roles.go @@ -78,6 +78,7 @@ func stopRocksmq() { type component interface { healthz.Indicator + Prepare() error Run() error Stop() error } @@ -121,6 +122,9 @@ func runComponent[T component](ctx context.Context, if err != nil { panic(err) } + if err := role.Prepare(); err != nil { + panic(err) + } close(sign) if err := role.Run(); err != nil { panic(err) diff --git a/internal/distributed/datacoord/service.go b/internal/distributed/datacoord/service.go index 8d7271c2562e3..998bb21106ab7 100644 --- a/internal/distributed/datacoord/service.go +++ b/internal/distributed/datacoord/service.go @@ -19,8 +19,6 @@ package grpcdatacoord import ( "context" - "net" - "strconv" "sync" "time" @@ -51,6 +49,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" + "github.com/milvus-io/milvus/pkg/util/netutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/tikv" ) @@ -70,23 +69,37 @@ type Server struct { grpcErrChan chan error grpcServer *grpc.Server + listener *netutil.NetListener } // NewServer new data service grpc server -func NewServer(ctx context.Context, factory dependency.Factory, opts ...datacoord.Option) *Server { +func NewServer(ctx context.Context, factory dependency.Factory, opts ...datacoord.Option) (*Server, error) { ctx1, cancel := context.WithCancel(ctx) - s := &Server{ ctx: ctx1, cancel: cancel, grpcErrChan: make(chan error), } s.dataCoord = datacoord.CreateServer(s.ctx, factory, opts...) - return s + return s, nil } var getTiKVClient = tikv.GetTiKVClient +func (s *Server) Prepare() error { + listener, err := netutil.NewListener( + netutil.OptIP(paramtable.Get().DataCoordGrpcServerCfg.IP), + netutil.OptPort(paramtable.Get().DataCoordGrpcServerCfg.Port.GetAsInt()), + ) + if err != nil { + log.Warn("DataCoord fail to create net listener", zap.Error(err)) + return err + } + log.Info("DataCoord listen on", zap.String("address", listener.Addr().String()), zap.Int("port", listener.Port())) + s.listener = listener + return nil +} + func (s *Server) init() error { params := paramtable.Get() etcdConfig := ¶ms.EtcdCfg @@ -108,7 +121,7 @@ func (s *Server) init() error { } s.etcdCli = etcdCli s.dataCoord.SetEtcdClient(etcdCli) - s.dataCoord.SetAddress(params.DataCoordGrpcServerCfg.GetAddress()) + s.dataCoord.SetAddress(s.listener.Address()) if params.MetaStoreCfg.MetaStoreType.GetValue() == util.MetaStoreTypeTiKV { log.Info("Connecting to tikv metadata storage.") @@ -135,26 +148,18 @@ func (s *Server) init() error { } func (s *Server) startGrpc() error { - Params := ¶mtable.Get().DataCoordGrpcServerCfg s.grpcWG.Add(1) - go s.startGrpcLoop(Params.Port.GetAsInt()) + go s.startGrpcLoop() // wait for grpc server loop start err := <-s.grpcErrChan return err } -func (s *Server) startGrpcLoop(grpcPort int) { +func (s *Server) startGrpcLoop() { defer logutil.LogPanic() defer s.grpcWG.Done() Params := ¶mtable.Get().DataCoordGrpcServerCfg - log.Debug("network port", zap.Int("port", grpcPort)) - lis, err := net.Listen("tcp", ":"+strconv.Itoa(grpcPort)) - if err != nil { - log.Error("grpc server failed to listen error", zap.Error(err)) - s.grpcErrChan <- err - return - } ctx, cancel := context.WithCancel(s.ctx) defer cancel() @@ -204,7 +209,7 @@ func (s *Server) startGrpcLoop(grpcPort int) { s.dataCoord.RegisterStreamingCoordGRPCService(s.grpcServer) } go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) - if err := s.grpcServer.Serve(lis); err != nil { + if err := s.grpcServer.Serve(s.listener); err != nil { s.grpcErrChan <- err } } @@ -227,8 +232,10 @@ func (s *Server) start() error { // Stop stops the DataCoord server gracefully. // Need to call the GracefulStop interface of grpc server and call the stop method of the inner DataCoord object. func (s *Server) Stop() (err error) { - Params := ¶mtable.Get().DataCoordGrpcServerCfg - logger := log.With(zap.String("address", Params.GetAddress())) + logger := log.With() + if s.listener != nil { + logger = log.With(zap.String("address", s.listener.Address())) + } logger.Info("Datacoord stopping") defer func() { logger.Info("Datacoord stopped", zap.Error(err)) @@ -251,8 +258,12 @@ func (s *Server) Stop() (err error) { log.Error("failed to close dataCoord", zap.Error(err)) return err } - s.cancel() + + // release the listener + if s.listener != nil { + s.listener.Close() + } return nil } diff --git a/internal/distributed/datacoord/service_test.go b/internal/distributed/datacoord/service_test.go index 955b7476ea4a9..fcad74a636b0f 100644 --- a/internal/distributed/datacoord/service_test.go +++ b/internal/distributed/datacoord/service_test.go @@ -41,7 +41,8 @@ func Test_NewServer(t *testing.T) { ctx := context.Background() mockDataCoord := mocks.NewMockDataCoord(t) - server := NewServer(ctx, nil) + server, err := NewServer(ctx, nil) + assert.NoError(t, err) assert.NotNil(t, server) server.dataCoord = mockDataCoord @@ -342,7 +343,8 @@ func Test_Run(t *testing.T) { defer func() { getTiKVClient = tikv.GetTiKVClient }() - server := NewServer(ctx, nil) + server, err := NewServer(ctx, nil) + assert.NoError(t, err) assert.NotNil(t, server) mockDataCoord := mocks.NewMockDataCoord(t) @@ -354,7 +356,9 @@ func Test_Run(t *testing.T) { mockDataCoord.EXPECT().Init().Return(nil) mockDataCoord.EXPECT().Start().Return(nil) mockDataCoord.EXPECT().Register().Return(nil) - err := server.Run() + err = server.Prepare() + assert.NoError(t, err) + err = server.Run() assert.NoError(t, err) mockDataCoord.EXPECT().Stop().Return(nil) @@ -367,15 +371,18 @@ func Test_Run(t *testing.T) { t.Run("test init error", func(t *testing.T) { ctx := context.Background() - server := NewServer(ctx, nil) + server, err := NewServer(ctx, nil) assert.NotNil(t, server) + assert.NoError(t, err) mockDataCoord := mocks.NewMockDataCoord(t) mockDataCoord.EXPECT().SetEtcdClient(mock.Anything) mockDataCoord.EXPECT().SetAddress(mock.Anything) mockDataCoord.EXPECT().Init().Return(errors.New("error")) server.dataCoord = mockDataCoord - err := server.Run() + err = server.Prepare() + assert.NoError(t, err) + err = server.Run() assert.Error(t, err) mockDataCoord.EXPECT().Stop().Return(nil) @@ -384,7 +391,8 @@ func Test_Run(t *testing.T) { t.Run("test register error", func(t *testing.T) { ctx := context.Background() - server := NewServer(ctx, nil) + server, err := NewServer(ctx, nil) + assert.NoError(t, err) assert.NotNil(t, server) mockDataCoord := mocks.NewMockDataCoord(t) mockDataCoord.EXPECT().SetEtcdClient(mock.Anything) @@ -393,7 +401,9 @@ func Test_Run(t *testing.T) { mockDataCoord.EXPECT().Register().Return(errors.New("error")) server.dataCoord = mockDataCoord - err := server.Run() + err = server.Prepare() + assert.NoError(t, err) + err = server.Run() assert.Error(t, err) mockDataCoord.EXPECT().Stop().Return(nil) @@ -402,7 +412,8 @@ func Test_Run(t *testing.T) { t.Run("test start error", func(t *testing.T) { ctx := context.Background() - server := NewServer(ctx, nil) + server, err := NewServer(ctx, nil) + assert.NoError(t, err) assert.NotNil(t, server) mockDataCoord := mocks.NewMockDataCoord(t) mockDataCoord.EXPECT().SetEtcdClient(mock.Anything) @@ -412,7 +423,9 @@ func Test_Run(t *testing.T) { mockDataCoord.EXPECT().Start().Return(errors.New("error")) server.dataCoord = mockDataCoord - err := server.Run() + err = server.Prepare() + assert.NoError(t, err) + err = server.Run() assert.Error(t, err) mockDataCoord.EXPECT().Stop().Return(nil) @@ -421,7 +434,8 @@ func Test_Run(t *testing.T) { t.Run("test stop error", func(t *testing.T) { ctx := context.Background() - server := NewServer(ctx, nil) + server, err := NewServer(ctx, nil) + assert.NoError(t, err) assert.NotNil(t, server) mockDataCoord := mocks.NewMockDataCoord(t) mockDataCoord.EXPECT().SetEtcdClient(mock.Anything) @@ -431,7 +445,9 @@ func Test_Run(t *testing.T) { mockDataCoord.EXPECT().Start().Return(nil) server.dataCoord = mockDataCoord - err := server.Run() + err = server.Prepare() + assert.NoError(t, err) + err = server.Run() assert.NoError(t, err) mockDataCoord.EXPECT().Stop().Return(errors.New("error")) diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index 7d538d4c93738..5e4ae6f0095e9 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -18,8 +18,6 @@ package grpcdatanode import ( "context" - "fmt" - "net" "strconv" "sync" "time" @@ -50,8 +48,8 @@ import ( "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/netutil" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/milvus-io/milvus/pkg/util/retry" ) type Server struct { @@ -59,6 +57,7 @@ type Server struct { grpcWG sync.WaitGroup grpcErrChan chan error grpcServer *grpc.Server + listener *netutil.NetListener ctx context.Context cancel context.CancelFunc etcdCli *clientv3.Client @@ -66,9 +65,6 @@ type Server struct { serverID atomic.Int64 - rootCoord types.RootCoord - dataCoord types.DataCoord - newRootCoordClient func() (types.RootCoordClient, error) newDataCoordClient func() (types.DataCoordClient, error) } @@ -94,17 +90,33 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) return s, nil } +func (s *Server) Prepare() error { + listener, err := netutil.NewListener( + netutil.OptIP(paramtable.Get().DataNodeGrpcServerCfg.IP), + netutil.OptHighPriorityToUsePort(paramtable.Get().DataNodeGrpcServerCfg.Port.GetAsInt()), + ) + if err != nil { + log.Warn("DataNode fail to create net listener", zap.Error(err)) + return err + } + log.Info("DataNode listen on", zap.String("address", listener.Addr().String()), zap.Int("port", listener.Port())) + s.listener = listener + paramtable.Get().Save( + paramtable.Get().DataNodeGrpcServerCfg.Port.Key, + strconv.FormatInt(int64(listener.Port()), 10)) + return nil +} + func (s *Server) startGrpc() error { - Params := ¶mtable.Get().DataNodeGrpcServerCfg s.grpcWG.Add(1) - go s.startGrpcLoop(Params.Port.GetAsInt()) + go s.startGrpcLoop() // wait for grpc server loop start err := <-s.grpcErrChan return err } // startGrpcLoop starts the grep loop of datanode component. -func (s *Server) startGrpcLoop(grpcPort int) { +func (s *Server) startGrpcLoop() { defer s.grpcWG.Done() Params := ¶mtable.Get().DataNodeGrpcServerCfg kaep := keepalive.EnforcementPolicy{ @@ -116,19 +128,6 @@ func (s *Server) startGrpcLoop(grpcPort int) { Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } - var lis net.Listener - - err := retry.Do(s.ctx, func() error { - addr := ":" + strconv.Itoa(grpcPort) - var err error - lis, err = net.Listen("tcp", addr) - return err - }, retry.Attempts(10)) - if err != nil { - log.Error("DataNode GrpcServer:failed to listen", zap.Error(err)) - s.grpcErrChan <- err - return - } s.grpcServer = grpc.NewServer( grpc.KeepaliveEnforcementPolicy(kaep), @@ -162,7 +161,7 @@ func (s *Server) startGrpcLoop(grpcPort int) { defer cancel() go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) - if err := s.grpcServer.Serve(lis); err != nil { + if err := s.grpcServer.Serve(s.listener); err != nil { log.Warn("DataNode failed to start gRPC") s.grpcErrChan <- err } @@ -197,8 +196,10 @@ func (s *Server) Run() error { // Stop stops Datanode's grpc service. func (s *Server) Stop() (err error) { - Params := ¶mtable.Get().DataNodeGrpcServerCfg - logger := log.With(zap.String("address", Params.GetAddress())) + logger := log.With() + if s.listener != nil { + logger = log.With(zap.String("address", s.listener.Address())) + } logger.Info("datanode stopping") defer func() { logger.Info("datanode stopped", zap.Error(err)) @@ -219,18 +220,17 @@ func (s *Server) Stop() (err error) { return err } s.cancel() + + if s.listener != nil { + s.listener.Close() + } return nil } // init initializes Datanode's grpc service. func (s *Server) init() error { etcdConfig := ¶mtable.Get().EtcdCfg - Params := ¶mtable.Get().DataNodeGrpcServerCfg ctx := context.Background() - if !funcutil.CheckPortAvailable(Params.Port.GetAsInt()) { - paramtable.Get().Save(Params.Port.Key, fmt.Sprintf("%d", funcutil.GetAvailablePort())) - log.Warn("DataNode found available port during init", zap.Int("port", Params.Port.GetAsInt())) - } etcdCli, err := etcd.CreateEtcdClient( etcdConfig.UseEmbedEtcd.GetAsBool(), @@ -249,8 +249,8 @@ func (s *Server) init() error { } s.etcdCli = etcdCli s.SetEtcdClient(s.etcdCli) - s.datanode.SetAddress(Params.GetAddress()) - log.Info("DataNode address", zap.String("address", Params.IP+":"+strconv.Itoa(Params.Port.GetAsInt()))) + s.datanode.SetAddress(s.listener.Address()) + log.Info("DataNode address", zap.String("address", s.listener.Address())) log.Info("DataNode serverID", zap.Int64("serverID", s.serverID.Load())) err = s.startGrpc() diff --git a/internal/distributed/datanode/service_test.go b/internal/distributed/datanode/service_test.go index 66390ae4fcb13..640ee87e28916 100644 --- a/internal/distributed/datanode/service_test.go +++ b/internal/distributed/datanode/service_test.go @@ -227,6 +227,8 @@ func Test_NewServer(t *testing.T) { t.Run("Run", func(t *testing.T) { server.datanode = &MockDataNode{} + err = server.Prepare() + assert.NoError(t, err) err = server.Run() assert.NoError(t, err) }) @@ -335,6 +337,8 @@ func Test_NewServer(t *testing.T) { } func Test_Run(t *testing.T) { + paramtable.Init() + ctx := context.Background() server, err := NewServer(ctx, nil) assert.NoError(t, err) @@ -376,6 +380,8 @@ func Test_Run(t *testing.T) { regErr: errors.New("error"), } + err = server.Prepare() + assert.NoError(t, err) err = server.Run() assert.Error(t, err) diff --git a/internal/distributed/indexnode/service.go b/internal/distributed/indexnode/service.go index e888106f614fa..403343ee907c4 100644 --- a/internal/distributed/indexnode/service.go +++ b/internal/distributed/indexnode/service.go @@ -18,8 +18,6 @@ package grpcindexnode import ( "context" - "fmt" - "net" "strconv" "sync" "time" @@ -46,6 +44,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" + "github.com/milvus-io/milvus/pkg/util/netutil" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -54,6 +53,7 @@ type Server struct { indexnode types.IndexNodeComponent grpcServer *grpc.Server + listener *netutil.NetListener grpcErrChan chan error serverID atomic.Int64 @@ -65,6 +65,23 @@ type Server struct { etcdCli *clientv3.Client } +func (s *Server) Prepare() error { + listener, err := netutil.NewListener( + netutil.OptIP(paramtable.Get().IndexNodeGrpcServerCfg.IP), + netutil.OptHighPriorityToUsePort(paramtable.Get().IndexNodeGrpcServerCfg.Port.GetAsInt()), + ) + if err != nil { + log.Warn("IndexNode fail to create net listener", zap.Error(err)) + return err + } + s.listener = listener + log.Info("IndexNode listen on", zap.String("address", listener.Addr().String()), zap.Int("port", listener.Port())) + paramtable.Get().Save( + paramtable.Get().IndexNodeGrpcServerCfg.Port.Key, + strconv.FormatInt(int64(listener.Port()), 10)) + return nil +} + // Run initializes and starts IndexNode's grpc service. func (s *Server) Run() error { if err := s.init(); err != nil { @@ -79,17 +96,10 @@ func (s *Server) Run() error { } // startGrpcLoop starts the grep loop of IndexNode component. -func (s *Server) startGrpcLoop(grpcPort int) { +func (s *Server) startGrpcLoop() { defer s.grpcWG.Done() Params := ¶mtable.Get().IndexNodeGrpcServerCfg - log.Debug("IndexNode", zap.String("network address", Params.GetAddress()), zap.Int("network port: ", grpcPort)) - lis, err := net.Listen("tcp", ":"+strconv.Itoa(grpcPort)) - if err != nil { - log.Warn("IndexNode", zap.Error(err)) - s.grpcErrChan <- err - return - } ctx, cancel := context.WithCancel(s.loopCtx) defer cancel() @@ -132,7 +142,7 @@ func (s *Server) startGrpcLoop(grpcPort int) { grpc.StatsHandler(tracer.GetDynamicOtelGrpcServerStatsHandler())) workerpb.RegisterIndexNodeServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) - if err := s.grpcServer.Serve(lis); err != nil { + if err := s.grpcServer.Serve(s.listener); err != nil { s.grpcErrChan <- err } } @@ -140,12 +150,7 @@ func (s *Server) startGrpcLoop(grpcPort int) { // init initializes IndexNode's grpc service. func (s *Server) init() error { etcdConfig := ¶mtable.Get().EtcdCfg - Params := ¶mtable.Get().IndexNodeGrpcServerCfg var err error - if !funcutil.CheckPortAvailable(Params.Port.GetAsInt()) { - paramtable.Get().Save(Params.Port.Key, fmt.Sprintf("%d", funcutil.GetAvailablePort())) - log.Warn("IndexNode get available port when init", zap.Int("Port", Params.Port.GetAsInt())) - } defer func() { if err != nil { @@ -157,7 +162,7 @@ func (s *Server) init() error { }() s.grpcWG.Add(1) - go s.startGrpcLoop(Params.Port.GetAsInt()) + go s.startGrpcLoop() // wait for grpc server loop start err = <-s.grpcErrChan if err != nil { @@ -183,7 +188,7 @@ func (s *Server) init() error { } s.etcdCli = etcdCli s.indexnode.SetEtcdClient(etcdCli) - s.indexnode.SetAddress(Params.GetAddress()) + s.indexnode.SetAddress(s.listener.Address()) err = s.indexnode.Init() if err != nil { log.Error("IndexNode Init failed", zap.Error(err)) @@ -210,8 +215,10 @@ func (s *Server) start() error { // Stop stops IndexNode's grpc service. func (s *Server) Stop() (err error) { - Params := ¶mtable.Get().IndexNodeGrpcServerCfg - logger := log.With(zap.String("address", Params.GetAddress())) + logger := log.With() + if s.listener != nil { + logger = log.With(zap.String("address", s.listener.Address())) + } logger.Info("IndexNode stopping") defer func() { logger.Info("IndexNode stopped", zap.Error(err)) @@ -233,6 +240,9 @@ func (s *Server) Stop() (err error) { s.grpcWG.Wait() s.loopCancel() + if s.listener != nil { + s.listener.Close() + } return nil } diff --git a/internal/distributed/indexnode/service_test.go b/internal/distributed/indexnode/service_test.go index a8c56e73d749f..9caea2c2663a8 100644 --- a/internal/distributed/indexnode/service_test.go +++ b/internal/distributed/indexnode/service_test.go @@ -52,6 +52,8 @@ func TestIndexNodeServer(t *testing.T) { err = server.setServer(inm) assert.NoError(t, err) + err = server.Prepare() + assert.NoError(t, err) err = server.Run() assert.NoError(t, err) diff --git a/internal/distributed/proxy/listener_manager.go b/internal/distributed/proxy/listener_manager.go new file mode 100644 index 0000000000000..db25119e01fed --- /dev/null +++ b/internal/distributed/proxy/listener_manager.go @@ -0,0 +1,216 @@ +package grpcproxy + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net" + + "github.com/cockroachdb/errors" + "github.com/soheilhy/cmux" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/netutil" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +// newListenerManager creates a new listener +func newListenerManager() (l *listenerManager, err error) { + defer func() { + if err != nil && l != nil { + l.Close() + } + }() + + externalGrpcListener, err := netutil.NewListener( + netutil.OptIP(paramtable.Get().ProxyGrpcServerCfg.IP), + netutil.OptPort(paramtable.Get().ProxyGrpcServerCfg.Port.GetAsInt()), + ) + if err != nil { + log.Warn("Proxy fail to create external grpc listener", zap.Error(err)) + return + } + log.Info("Proxy listen on external grpc listener", zap.String("address", externalGrpcListener.Address()), zap.Int("port", externalGrpcListener.Port())) + + internalGrpcListener, err := netutil.NewListener( + netutil.OptIP(paramtable.Get().ProxyGrpcServerCfg.IP), + netutil.OptPort(paramtable.Get().ProxyGrpcServerCfg.InternalPort.GetAsInt()), + ) + if err != nil { + log.Warn("Proxy fail to create internal grpc listener", zap.Error(err)) + return + } + log.Info("Proxy listen on internal grpc listener", zap.String("address", internalGrpcListener.Address()), zap.Int("port", internalGrpcListener.Port())) + + l = &listenerManager{ + externalGrpcListener: externalGrpcListener, + internalGrpcListener: internalGrpcListener, + } + if err = newHTTPListner(l); err != nil { + return + } + return +} + +// newHTTPListner creates a new http listener +func newHTTPListner(l *listenerManager) error { + HTTPParams := ¶mtable.Get().HTTPCfg + if !HTTPParams.Enabled.GetAsBool() { + // http server is disabled + log.Info("Proxy server(http) is disabled, skip initialize http listener") + return nil + } + tlsMode := paramtable.Get().ProxyGrpcServerCfg.TLSMode.GetAsInt() + if tlsMode != 0 && tlsMode != 1 && tlsMode != 2 { + return errors.New("tls mode must be 0: no authentication, 1: one way authentication or 2: two way authentication") + } + + httpPortString := HTTPParams.Port.GetValue() + httpPort := HTTPParams.Port.GetAsInt() + externGrpcPort := l.externalGrpcListener.Port() + if len(httpPortString) == 0 || externGrpcPort == httpPort { + if tlsMode != 0 { + err := errors.New("proxy server(http) and external grpc server share the same port, tls mode must be 0") + log.Warn("can not initialize http listener", zap.Error(err)) + return err + } + log.Info("Proxy server(http) and external grpc server share the same port") + l.portShareMode = true + l.cmux = cmux.New(l.externalGrpcListener) + l.cmuxClosed = make(chan struct{}) + l.cmuxExternGrpcListener = l.cmux.MatchWithWriters(cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc")) + l.cmuxExternHTTPListener = l.cmux.Match(cmux.Any()) + go func() { + defer close(l.cmuxClosed) + if err := l.cmux.Serve(); err != nil && !errors.Is(err, net.ErrClosed) { + log.Warn("Proxy cmux server closed", zap.Error(err)) + return + } + log.Info("Proxy tcp server exited") + }() + return nil + } + + Params := ¶mtable.Get().ProxyGrpcServerCfg + var tlsConf *tls.Config + switch tlsMode { + case 1: + creds, err := tls.LoadX509KeyPair(Params.ServerPemPath.GetValue(), Params.ServerKeyPath.GetValue()) + if err != nil { + log.Error("proxy can't create creds", zap.Error(err)) + return err + } + tlsConf = &tls.Config{Certificates: []tls.Certificate{creds}} + case 2: + cert, err := tls.LoadX509KeyPair(Params.ServerPemPath.GetValue(), Params.ServerKeyPath.GetValue()) + if err != nil { + log.Error("proxy cant load x509 key pair", zap.Error(err)) + return err + } + certPool := x509.NewCertPool() + rootBuf, err := storage.ReadFile(Params.CaPemPath.GetValue()) + if err != nil { + log.Error("failed read ca pem", zap.Error(err)) + return err + } + if !certPool.AppendCertsFromPEM(rootBuf) { + log.Warn("fail to append ca to cert") + return fmt.Errorf("fail to append ca to cert") + } + tlsConf = &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + Certificates: []tls.Certificate{cert}, + ClientCAs: certPool, + MinVersion: tls.VersionTLS13, + } + } + + var err error + l.portShareMode = false + l.httpListener, err = netutil.NewListener(netutil.OptIP(Params.IP), netutil.OptPort(httpPort), netutil.OptTLS(tlsConf)) + if err != nil { + log.Warn("Proxy server(http) failed to listen on", zap.Error(err)) + return err + } + log.Info("Proxy server(http) listen on", zap.Int("port", l.httpListener.Port())) + return nil +} + +type listenerManager struct { + externalGrpcListener *netutil.NetListener + internalGrpcListener *netutil.NetListener + + portShareMode bool + // portShareMode == true + cmux cmux.CMux + cmuxClosed chan struct{} + cmuxExternGrpcListener net.Listener + cmuxExternHTTPListener net.Listener + + // portShareMode == false + httpListener *netutil.NetListener +} + +func (l *listenerManager) ExternalGrpcListener() net.Listener { + if l.portShareMode { + return l.cmuxExternGrpcListener + } + return l.externalGrpcListener +} + +func (l *listenerManager) InternalGrpcListener() net.Listener { + return l.internalGrpcListener +} + +func (l *listenerManager) HTTPListener() net.Listener { + if l.portShareMode { + return l.cmuxExternHTTPListener + } + // httpListener maybe nil if http server is disabled + if l.httpListener == nil { + return nil + } + return l.httpListener +} + +func (l *listenerManager) Close() { + if l.portShareMode { + if l.cmux != nil { + log.Info("Proxy close cmux grpc listener") + if err := l.cmuxExternGrpcListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + log.Warn("Proxy failed to close cmux grpc listener", zap.Error(err)) + } + log.Info("Proxy close cmux http listener") + if err := l.cmuxExternHTTPListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + log.Warn("Proxy failed to close cmux http listener", zap.Error(err)) + } + log.Info("Proxy close cmux...") + l.cmux.Close() + <-l.cmuxClosed + log.Info("Proxy cmux closed") + } + } else { + if l.httpListener != nil { + log.Info("Proxy close http listener", zap.String("address", l.httpListener.Address())) + if err := l.httpListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + log.Warn("Proxy failed to close http listener", zap.Error(err)) + } + } + } + + if l.internalGrpcListener != nil { + log.Info("Proxy close internal grpc listener", zap.String("address", l.internalGrpcListener.Address())) + if err := l.internalGrpcListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + log.Warn("Proxy failed to close internal grpc listener", zap.Error(err)) + } + } + + if l.externalGrpcListener != nil { + log.Info("Proxy close external grpc listener", zap.String("address", l.externalGrpcListener.Address())) + if err := l.externalGrpcListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + log.Warn("Proxy failed to close external grpc listener", zap.Error(err)) + } + } +} diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index 707b5dfcef1e6..b7f6e768d269f 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -22,7 +22,6 @@ import ( "crypto/x509" "fmt" "io" - "net" "net/http" "os" "strconv" @@ -30,7 +29,6 @@ import ( "sync" "time" - "github.com/cockroachdb/errors" "github.com/gin-gonic/gin" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" @@ -70,7 +68,6 @@ import ( "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/etcd" - "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -98,12 +95,10 @@ type Server struct { ctx context.Context wg sync.WaitGroup proxy types.ProxyComponent - httpListener net.Listener - grpcListener net.Listener - tcpServer cmux.CMux httpServer *http.Server grpcInternalServer *grpc.Server grpcExternalServer *grpc.Server + listenerManager *listenerManager serverID atomic.Int64 @@ -115,11 +110,11 @@ type Server struct { // NewServer create a Proxy server. func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) { - var err error server := &Server{ ctx: ctx, } + var err error server.proxy, err = proxy.NewProxy(server.ctx, factory) if err != nil { return nil, err @@ -259,7 +254,7 @@ func (s *Server) startHTTPServer(errChan chan error) { httpserver.NewHandlersV2(s.proxy).RegisterRoutesToV2(appV2) s.httpServer = &http.Server{Handler: ginHandler, ReadHeaderTimeout: time.Second} errChan <- nil - if err := s.httpServer.Serve(s.httpListener); err != nil && err != cmux.ErrServerClosed { + if err := s.httpServer.Serve(s.listenerManager.HTTPListener()); err != nil && err != cmux.ErrServerClosed { log.Error("start Proxy http server to listen failed", zap.Error(err)) errChan <- err return @@ -267,17 +262,17 @@ func (s *Server) startHTTPServer(errChan chan error) { log.Info("Proxy http server exited") } -func (s *Server) startInternalRPCServer(grpcInternalPort int, errChan chan error) { +func (s *Server) startInternalRPCServer(errChan chan error) { s.wg.Add(1) - go s.startInternalGrpc(grpcInternalPort, errChan) + go s.startInternalGrpc(errChan) } -func (s *Server) startExternalRPCServer(grpcExternalPort int, errChan chan error) { +func (s *Server) startExternalRPCServer(errChan chan error) { s.wg.Add(1) - go s.startExternalGrpc(grpcExternalPort, errChan) + go s.startExternalGrpc(errChan) } -func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) { +func (s *Server) startExternalGrpc(errChan chan error) { defer s.wg.Done() Params := ¶mtable.Get().ProxyGrpcServerCfg kaep := keepalive.EnforcementPolicy{ @@ -292,11 +287,11 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) { limiter, err := s.proxy.GetRateLimiter() if err != nil { - log.Error("Get proxy rate limiter failed", zap.Int("port", grpcPort), zap.Error(err)) + log.Error("Get proxy rate limiter failed", zap.Error(err)) errChan <- err return } - log.Debug("Get proxy rate limiter done", zap.Int("port", grpcPort)) + log.Debug("Get proxy rate limiter done") var unaryServerOption grpc.ServerOption if enableCustomInterceptor { @@ -387,7 +382,7 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) { zap.Any("enforcement policy", kaep), zap.Any("server parameters", kasp)) - if err := s.grpcExternalServer.Serve(s.grpcListener); err != nil && err != cmux.ErrServerClosed { + if err := s.grpcExternalServer.Serve(s.listenerManager.ExternalGrpcListener()); err != nil && err != cmux.ErrServerClosed { log.Error("failed to serve on Proxy's listener", zap.Error(err)) errChan <- err return @@ -395,7 +390,7 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) { log.Info("Proxy external grpc server exited") } -func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) { +func (s *Server) startInternalGrpc(errChan chan error) { defer s.wg.Done() Params := ¶mtable.Get().ProxyGrpcServerCfg kaep := keepalive.EnforcementPolicy{ @@ -408,15 +403,6 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) { Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } - log.Info("Proxy internal server listen on tcp", zap.Int("port", grpcPort)) - lis, err := net.Listen("tcp", ":"+strconv.Itoa(grpcPort)) - if err != nil { - log.Warn("Proxy internal server failed to listen on", zap.Error(err), zap.Int("port", grpcPort)) - errChan <- err - return - } - log.Info("Proxy internal server already listen on tcp", zap.Int("port", grpcPort)) - opts := tracer.GetInterceptorOpts() s.grpcInternalServer = grpc.NewServer( grpc.KeepaliveEnforcementPolicy(kaep), @@ -451,7 +437,7 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) { zap.Any("enforcement policy", kaep), zap.Any("server parameters", kasp)) - if err := s.grpcInternalServer.Serve(lis); err != nil { + if err := s.grpcInternalServer.Serve(s.listenerManager.InternalGrpcListener()); err != nil { log.Error("failed to internal serve on Proxy's listener", zap.Error(err)) errChan <- err return @@ -459,6 +445,15 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) { log.Info("Proxy internal grpc server exited") } +func (s *Server) Prepare() error { + listenerManager, err := newListenerManager() + if err != nil { + return err + } + s.listenerManager = listenerManager + return nil +} + // Start start the Proxy Server func (s *Server) Run() error { log.Info("init Proxy server") @@ -474,18 +469,6 @@ func (s *Server) Run() error { return err } log.Info("start Proxy server done") - - if s.tcpServer != nil { - s.wg.Add(1) - go func() { - defer s.wg.Done() - if err := s.tcpServer.Serve(); err != nil && !errors.Is(err, net.ErrClosed) { - log.Warn("Proxy server for tcp port failed", zap.Error(err)) - return - } - log.Info("Proxy tcp server exited") - }() - } return nil } @@ -496,16 +479,6 @@ func (s *Server) init() error { HTTPParams := ¶mtable.Get().HTTPCfg log.Info("Proxy init http server's parameter table done") - if !funcutil.CheckPortAvailable(Params.Port.GetAsInt()) { - paramtable.Get().Save(Params.Port.Key, fmt.Sprintf("%d", funcutil.GetAvailablePort())) - log.Warn("Proxy get available port when init", zap.Int("Port", Params.Port.GetAsInt())) - } - - log.Info("init Proxy's parameter table done", - zap.String("internalAddress", Params.GetInternalAddress()), - zap.String("externalAddress", Params.GetAddress()), - ) - accesslog.InitAccessLogger(paramtable.Get()) serviceName := fmt.Sprintf("Proxy ip: %s, port: %d", Params.IP, Params.Port.GetAsInt()) log.Info("init Proxy's tracer done", zap.String("service name", serviceName)) @@ -527,95 +500,18 @@ func (s *Server) init() error { } s.etcdCli = etcdCli s.proxy.SetEtcdClient(s.etcdCli) - s.proxy.SetAddress(Params.GetInternalAddress()) + s.proxy.SetAddress(s.listenerManager.internalGrpcListener.Address()) errChan := make(chan error, 1) { - s.startInternalRPCServer(Params.InternalPort.GetAsInt(), errChan) + s.startInternalRPCServer(errChan) if err := <-errChan; err != nil { log.Error("failed to create internal rpc server", zap.Error(err)) return err } } { - port := Params.Port.GetAsInt() - httpPort := HTTPParams.Port.GetAsInt() - log.Info("Proxy server listen on tcp", zap.Int("port", port)) - var lis net.Listener - - log.Info("Proxy server already listen on tcp", zap.Int("port", httpPort)) - lis, err = net.Listen("tcp", ":"+strconv.Itoa(port)) - if err != nil { - log.Error("Proxy server(grpc/http) failed to listen on", zap.Int("port", port), zap.Error(err)) - return err - } - - if HTTPParams.Enabled.GetAsBool() && - Params.TLSMode.GetAsInt() == 0 && - (HTTPParams.Port.GetValue() == "" || httpPort == port) { - s.tcpServer = cmux.New(lis) - s.grpcListener = s.tcpServer.MatchWithWriters(cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc")) - s.httpListener = s.tcpServer.Match(cmux.Any()) - } else { - s.grpcListener = lis - } - - if HTTPParams.Enabled.GetAsBool() && - HTTPParams.Port.GetValue() != "" && - httpPort != port { - if Params.TLSMode.GetAsInt() == 0 { - s.httpListener, err = net.Listen("tcp", ":"+strconv.Itoa(httpPort)) - if err != nil { - log.Error("Proxy server(grpc/http) failed to listen on", zap.Int("port", port), zap.Error(err)) - return err - } - } else if Params.TLSMode.GetAsInt() == 1 { - creds, err := tls.LoadX509KeyPair(Params.ServerPemPath.GetValue(), Params.ServerKeyPath.GetValue()) - if err != nil { - log.Error("proxy can't create creds", zap.Error(err)) - return err - } - s.httpListener, err = tls.Listen("tcp", ":"+strconv.Itoa(httpPort), &tls.Config{ - Certificates: []tls.Certificate{creds}, - }) - if err != nil { - log.Error("Proxy server(grpc/http) failed to listen on", zap.Int("port", port), zap.Error(err)) - return err - } - } else if Params.TLSMode.GetAsInt() == 2 { - cert, err := tls.LoadX509KeyPair(Params.ServerPemPath.GetValue(), Params.ServerKeyPath.GetValue()) - if err != nil { - log.Error("proxy cant load x509 key pair", zap.Error(err)) - return err - } - - certPool := x509.NewCertPool() - rootBuf, err := storage.ReadFile(Params.CaPemPath.GetValue()) - if err != nil { - log.Error("failed read ca pem", zap.Error(err)) - return err - } - if !certPool.AppendCertsFromPEM(rootBuf) { - log.Warn("fail to append ca to cert") - return fmt.Errorf("fail to append ca to cert") - } - - tlsConf := &tls.Config{ - ClientAuth: tls.RequireAndVerifyClientCert, - Certificates: []tls.Certificate{cert}, - ClientCAs: certPool, - MinVersion: tls.VersionTLS13, - } - s.httpListener, err = tls.Listen("tcp", ":"+strconv.Itoa(httpPort), tlsConf) - if err != nil { - log.Error("Proxy server(grpc/http) failed to listen on", zap.Int("port", port), zap.Error(err)) - return err - } - } - } - } - { - s.startExternalRPCServer(Params.Port.GetAsInt(), errChan) + s.startExternalRPCServer(errChan) if err := <-errChan; err != nil { log.Error("failed to create external rpc server", zap.Error(err)) return err @@ -722,7 +618,7 @@ func (s *Server) start() error { return err } - if s.httpListener != nil { + if s.listenerManager.HTTPListener() != nil { log.Info("start Proxy http server") errChan := make(chan error, 1) s.wg.Add(1) @@ -738,8 +634,12 @@ func (s *Server) start() error { // Stop stop the Proxy Server func (s *Server) Stop() (err error) { - Params := ¶mtable.Get().ProxyGrpcServerCfg - logger := log.With(zap.String("internal address", Params.GetInternalAddress()), zap.String("external address", Params.GetInternalAddress())) + logger := log.With() + if s.listenerManager != nil { + logger = log.With( + zap.String("internal address", s.listenerManager.internalGrpcListener.Address()), + zap.String("external address", s.listenerManager.externalGrpcListener.Address())) + } logger.Info("Proxy stopping") defer func() { logger.Info("Proxy stopped", zap.Error(err)) @@ -767,21 +667,14 @@ func (s *Server) Stop() (err error) { s.httpServer.Close() } - // close cmux server, it isn't a synchronized operation. - // Note that: - // 1. all listeners can be closed after closing cmux server that has the root listener, it will automatically - // propagate the closure to all the listeners derived from it, but it doesn't provide a graceful shutdown - // grpc server ideally. - // 2. avoid resource leak also need to close cmux after grpc and http listener closed. - if s.tcpServer != nil { - log.Info("Proxy stop tcp server...") - s.tcpServer.Close() - } - if s.grpcInternalServer != nil { log.Info("Proxy stop internal grpc server") utils.GracefulStopGRPCServer(s.grpcInternalServer) } + + if s.listenerManager != nil { + s.listenerManager.Close() + } }() gracefulWg.Wait() diff --git a/internal/distributed/proxy/service_test.go b/internal/distributed/proxy/service_test.go index 936785a88d27e..bb73295921810 100644 --- a/internal/distributed/proxy/service_test.go +++ b/internal/distributed/proxy/service_test.go @@ -177,6 +177,9 @@ func waitForServerReady() { } func runAndWaitForServerReady(server *Server) error { + if err := server.Prepare(); err != nil { + return err + } err := server.Run() if err != nil { return err @@ -1036,7 +1039,8 @@ func Test_NewHTTPServer_TLS_TwoWay(t *testing.T) { paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "19529") err = runAndWaitForServerReady(server) assert.NotNil(t, err) - server.Stop() + err = server.Stop() + assert.Nil(t, err) } func Test_NewHTTPServer_TLS_OneWay(t *testing.T) { @@ -1064,12 +1068,14 @@ func Test_NewHTTPServer_TLS_OneWay(t *testing.T) { paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "8080") err := runAndWaitForServerReady(server) + fmt.Printf("err: %v\n", err) assert.Nil(t, err) assert.NotNil(t, server.grpcExternalServer) err = server.Stop() assert.Nil(t, err) paramtable.Get().Save(proxy.Params.HTTPCfg.Port.Key, "19529") + fmt.Printf("err: %v\n", err) err = runAndWaitForServerReady(server) assert.NotNil(t, err) server.Stop() @@ -1080,8 +1086,8 @@ func Test_NewHTTPServer_TLS_FileNotExisted(t *testing.T) { mockProxy := server.proxy.(*mocks.MockProxy) mockProxy.EXPECT().Stop().Return(nil) - mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return() - mockProxy.EXPECT().SetAddress(mock.Anything).Return() + mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return().Maybe() + mockProxy.EXPECT().SetAddress(mock.Anything).Return().Maybe() Params := ¶mtable.Get().ProxyGrpcServerCfg paramtable.Get().Save(Params.TLSMode.Key, "1") @@ -1220,7 +1226,9 @@ func Test_Service_GracefulStop(t *testing.T) { enableRegisterProxyServer = false }() - err := server.Run() + err := server.Prepare() + assert.Nil(t, err) + err = server.Run() assert.Nil(t, err) proxyClient, err := grpcproxyclient.NewClient(ctx, fmt.Sprintf("localhost:%s", Params.Port.GetValue()), 0) diff --git a/internal/distributed/querycoord/service.go b/internal/distributed/querycoord/service.go index c0cc568b0d701..25b903c4edc9c 100644 --- a/internal/distributed/querycoord/service.go +++ b/internal/distributed/querycoord/service.go @@ -18,8 +18,6 @@ package grpcquerycoord import ( "context" - "net" - "strconv" "sync" "time" @@ -50,6 +48,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" + "github.com/milvus-io/milvus/pkg/util/netutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/tikv" ) @@ -60,6 +59,7 @@ type Server struct { loopCtx context.Context loopCancel context.CancelFunc grpcServer *grpc.Server + listener *netutil.NetListener serverID atomic.Int64 @@ -94,17 +94,31 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) }, nil } +func (s *Server) Prepare() error { + listener, err := netutil.NewListener( + netutil.OptIP(paramtable.Get().QueryCoordGrpcServerCfg.IP), + netutil.OptPort(paramtable.Get().QueryCoordGrpcServerCfg.Port.GetAsInt()), + ) + if err != nil { + log.Warn("QueryCoord fail to create net listener", zap.Error(err)) + return err + } + log.Info("QueryCoord listen on", zap.String("address", listener.Addr().String()), zap.Int("port", listener.Port())) + s.listener = listener + return nil +} + // Run initializes and starts QueryCoord's grpc service. func (s *Server) Run() error { if err := s.init(); err != nil { return err } - log.Debug("QueryCoord init done ...") + log.Info("QueryCoord init done ...") if err := s.start(); err != nil { return err } - log.Debug("QueryCoord start done ...") + log.Info("QueryCoord start done ...") return nil } @@ -114,7 +128,6 @@ var getTiKVClient = tikv.GetTiKVClient func (s *Server) init() error { params := paramtable.Get() etcdConfig := ¶ms.EtcdCfg - rpcParams := ¶ms.QueryCoordGrpcServerCfg etcdCli, err := etcd.CreateEtcdClient( etcdConfig.UseEmbedEtcd.GetAsBool(), @@ -128,12 +141,12 @@ func (s *Server) init() error { etcdConfig.EtcdTLSCACert.GetValue(), etcdConfig.EtcdTLSMinVersion.GetValue()) if err != nil { - log.Debug("QueryCoord connect to etcd failed", zap.Error(err)) + log.Warn("QueryCoord connect to etcd failed", zap.Error(err)) return err } s.etcdCli = etcdCli s.SetEtcdClient(etcdCli) - s.queryCoord.SetAddress(rpcParams.GetAddress()) + s.queryCoord.SetAddress(s.listener.Address()) if params.MetaStoreCfg.MetaStoreType.GetValue() == util.MetaStoreTypeTiKV { log.Info("Connecting to tikv metadata storage.") @@ -147,7 +160,7 @@ func (s *Server) init() error { } s.grpcWG.Add(1) - go s.startGrpcLoop(rpcParams.Port.GetAsInt()) + go s.startGrpcLoop() // wait for grpc server loop start err = <-s.grpcErrChan if err != nil { @@ -164,7 +177,7 @@ func (s *Server) init() error { } // wait for master init or healthy - log.Debug("QueryCoord try to wait for RootCoord ready") + log.Info("QueryCoord try to wait for RootCoord ready") err = componentutil.WaitForComponentHealthy(s.loopCtx, s.rootCoord, "RootCoord", 1000000, time.Millisecond*200) if err != nil { log.Error("QueryCoord wait for RootCoord ready failed", zap.Error(err)) @@ -174,7 +187,7 @@ func (s *Server) init() error { if err := s.SetRootCoord(s.rootCoord); err != nil { panic(err) } - log.Debug("QueryCoord report RootCoord ready") + log.Info("QueryCoord report RootCoord ready") // --- Data service client --- if s.dataCoord == nil { @@ -185,7 +198,7 @@ func (s *Server) init() error { } } - log.Debug("QueryCoord try to wait for DataCoord ready") + log.Info("QueryCoord try to wait for DataCoord ready") err = componentutil.WaitForComponentHealthy(s.loopCtx, s.dataCoord, "DataCoord", 1000000, time.Millisecond*200) if err != nil { log.Error("QueryCoord wait for DataCoord ready failed", zap.Error(err)) @@ -194,7 +207,7 @@ func (s *Server) init() error { if err := s.SetDataCoord(s.dataCoord); err != nil { panic(err) } - log.Debug("QueryCoord report DataCoord ready") + log.Info("QueryCoord report DataCoord ready") if err := s.queryCoord.Init(); err != nil { return err @@ -202,7 +215,7 @@ func (s *Server) init() error { return nil } -func (s *Server) startGrpcLoop(grpcPort int) { +func (s *Server) startGrpcLoop() { defer s.grpcWG.Done() Params := ¶mtable.Get().QueryCoordGrpcServerCfg kaep := keepalive.EnforcementPolicy{ @@ -214,14 +227,6 @@ func (s *Server) startGrpcLoop(grpcPort int) { Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } - log.Debug("network", zap.String("port", strconv.Itoa(grpcPort))) - lis, err := net.Listen("tcp", ":"+strconv.Itoa(grpcPort)) - if err != nil { - log.Debug("GrpcServer:failed to listen:", zap.Error(err)) - s.grpcErrChan <- err - return - } - ctx, cancel := context.WithCancel(s.loopCtx) defer cancel() @@ -255,7 +260,7 @@ func (s *Server) startGrpcLoop(grpcPort int) { querypb.RegisterQueryCoordServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) - if err := s.grpcServer.Serve(lis); err != nil { + if err := s.grpcServer.Serve(s.listener); err != nil { s.grpcErrChan <- err } } @@ -275,8 +280,10 @@ func (s *Server) GetQueryCoord() types.QueryCoordComponent { // Stop stops QueryCoord's grpc service. func (s *Server) Stop() (err error) { - Params := ¶mtable.Get().QueryCoordGrpcServerCfg - logger := log.With(zap.String("address", Params.GetAddress())) + logger := log.With() + if s.listener != nil { + logger = log.With(zap.String("address", s.listener.Address())) + } logger.Info("QueryCoord stopping") defer func() { logger.Info("QueryCoord stopped", zap.Error(err)) @@ -296,6 +303,11 @@ func (s *Server) Stop() (err error) { log.Error("failed to close queryCoord", zap.Error(err)) } s.loopCancel() + + // release port resource + if s.listener != nil { + s.listener.Close() + } return nil } diff --git a/internal/distributed/querycoord/service_test.go b/internal/distributed/querycoord/service_test.go index 08ce7f7d77793..1e80f2bad3916 100644 --- a/internal/distributed/querycoord/service_test.go +++ b/internal/distributed/querycoord/service_test.go @@ -78,6 +78,8 @@ func Test_NewServer(t *testing.T) { server.dataCoord = mdc server.rootCoord = mrc + err = server.Prepare() + assert.NoError(t, err) err = server.Run() assert.NoError(t, err) }) diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index e60120050d7ca..e66884681aa5f 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -18,8 +18,6 @@ package grpcquerynode import ( "context" - "fmt" - "net" "strconv" "sync" "time" @@ -46,8 +44,8 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" + "github.com/milvus-io/milvus/pkg/util/netutil" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -65,6 +63,7 @@ type Server struct { serverID atomic.Int64 grpcServer *grpc.Server + listener *netutil.NetListener etcdCli *clientv3.Client } @@ -90,17 +89,27 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) return s, nil } +func (s *Server) Prepare() error { + listener, err := netutil.NewListener( + netutil.OptIP(paramtable.Get().QueryNodeGrpcServerCfg.IP), + netutil.OptHighPriorityToUsePort(paramtable.Get().QueryNodeGrpcServerCfg.Port.GetAsInt()), + ) + if err != nil { + log.Warn("QueryNode fail to create net listener", zap.Error(err)) + return err + } + s.listener = listener + log.Info("QueryNode listen on", zap.String("address", listener.Addr().String()), zap.Int("port", listener.Port())) + paramtable.Get().Save( + paramtable.Get().QueryNodeGrpcServerCfg.Port.Key, + strconv.FormatInt(int64(listener.Port()), 10)) + return nil +} + // init initializes QueryNode's grpc service. func (s *Server) init() error { etcdConfig := ¶mtable.Get().EtcdCfg - Params := ¶mtable.Get().QueryNodeGrpcServerCfg - - if !funcutil.CheckPortAvailable(Params.Port.GetAsInt()) { - paramtable.Get().Save(Params.Port.Key, fmt.Sprintf("%d", funcutil.GetAvailablePort())) - log.Warn("QueryNode get available port when init", zap.Int("Port", Params.Port.GetAsInt())) - } - - log.Debug("QueryNode", zap.Int("port", Params.Port.GetAsInt())) + log.Debug("QueryNode", zap.Int("port", s.listener.Port())) etcdCli, err := etcd.CreateEtcdClient( etcdConfig.UseEmbedEtcd.GetAsBool(), @@ -119,10 +128,10 @@ func (s *Server) init() error { } s.etcdCli = etcdCli s.SetEtcdClient(etcdCli) - s.querynode.SetAddress(Params.GetAddress()) + s.querynode.SetAddress(s.listener.Address()) log.Debug("QueryNode connect to etcd successfully") s.grpcWG.Add(1) - go s.startGrpcLoop(Params.Port.GetAsInt()) + go s.startGrpcLoop() // wait for grpc server loop start err = <-s.grpcErrChan if err != nil { @@ -154,7 +163,7 @@ func (s *Server) start() error { } // startGrpcLoop starts the grpc loop of QueryNode component. -func (s *Server) startGrpcLoop(grpcPort int) { +func (s *Server) startGrpcLoop() { defer s.grpcWG.Done() Params := ¶mtable.Get().QueryNodeGrpcServerCfg kaep := keepalive.EnforcementPolicy{ @@ -166,24 +175,6 @@ func (s *Server) startGrpcLoop(grpcPort int) { Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } - var lis net.Listener - var err error - err = retry.Do(s.ctx, func() error { - addr := ":" + strconv.Itoa(grpcPort) - lis, err = net.Listen("tcp", addr) - if err == nil { - s.querynode.SetAddress(fmt.Sprintf("%s:%d", Params.IP, lis.Addr().(*net.TCPAddr).Port)) - } else { - // set port=0 to get next available port - grpcPort = 0 - } - return err - }, retry.Attempts(10)) - if err != nil { - log.Error("QueryNode GrpcServer:failed to listen", zap.Error(err)) - s.grpcErrChan <- err - return - } s.grpcServer = grpc.NewServer( grpc.KeepaliveEnforcementPolicy(kaep), @@ -220,7 +211,7 @@ func (s *Server) startGrpcLoop(grpcPort int) { defer cancel() go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) - if err := s.grpcServer.Serve(lis); err != nil { + if err := s.grpcServer.Serve(s.listener); err != nil { log.Debug("QueryNode Start Grpc Failed!!!!") s.grpcErrChan <- err } @@ -242,8 +233,10 @@ func (s *Server) Run() error { // Stop stops QueryNode's grpc service. func (s *Server) Stop() (err error) { - Params := ¶mtable.Get().QueryNodeGrpcServerCfg - logger := log.With(zap.String("address", Params.GetAddress())) + logger := log.With() + if s.listener != nil { + logger = log.With(zap.String("address", s.listener.Address())) + } logger.Info("QueryNode stopping") defer func() { logger.Info("QueryNode stopped", zap.Error(err)) @@ -265,6 +258,9 @@ func (s *Server) Stop() (err error) { s.grpcWG.Wait() s.cancel() + if s.listener != nil { + s.listener.Close() + } return nil } diff --git a/internal/distributed/querynode/service_test.go b/internal/distributed/querynode/service_test.go index fc979387e64a5..caa69c14e62c5 100644 --- a/internal/distributed/querynode/service_test.go +++ b/internal/distributed/querynode/service_test.go @@ -95,6 +95,8 @@ func Test_NewServer(t *testing.T) { server.querynode = mockQN t.Run("Run", func(t *testing.T) { + err = server.Prepare() + assert.NoError(t, err) err = server.Run() assert.NoError(t, err) }) @@ -288,6 +290,8 @@ func Test_Run(t *testing.T) { mockQN.EXPECT().Init().Return(nil).Maybe() mockQN.EXPECT().GetNodeID().Return(2).Maybe() server.querynode = mockQN + err = server.Prepare() + assert.NoError(t, err) err = server.Run() assert.Error(t, err) diff --git a/internal/distributed/rootcoord/service.go b/internal/distributed/rootcoord/service.go index 9f50661fe9589..d49c3ae4a89f6 100644 --- a/internal/distributed/rootcoord/service.go +++ b/internal/distributed/rootcoord/service.go @@ -18,8 +18,6 @@ package grpcrootcoord import ( "context" - "net" - "strconv" "sync" "time" @@ -50,6 +48,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" + "github.com/milvus-io/milvus/pkg/util/netutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/tikv" ) @@ -58,6 +57,7 @@ import ( type Server struct { rootCoord types.RootCoordComponent grpcServer *grpc.Server + listener *netutil.NetListener grpcErrChan chan error grpcWG sync.WaitGroup @@ -142,6 +142,20 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) return s, err } +func (s *Server) Prepare() error { + listener, err := netutil.NewListener( + netutil.OptIP(paramtable.Get().RootCoordGrpcServerCfg.IP), + netutil.OptPort(paramtable.Get().RootCoordGrpcServerCfg.Port.GetAsInt()), + ) + if err != nil { + log.Warn("RootCoord fail to create net listener", zap.Error(err)) + return err + } + log.Info("RootCoord listen on", zap.String("address", listener.Addr().String()), zap.Int("port", listener.Port())) + s.listener = listener + return nil +} + func (s *Server) setClient() { s.newDataCoordClient = func() types.DataCoordClient { dsClient, err := dcc.NewClient(s.ctx) @@ -165,12 +179,12 @@ func (s *Server) Run() error { if err := s.init(); err != nil { return err } - log.Debug("RootCoord init done ...") + log.Info("RootCoord init done ...") if err := s.start(); err != nil { return err } - log.Debug("RootCoord start done ...") + log.Info("RootCoord start done ...") return nil } @@ -179,8 +193,7 @@ var getTiKVClient = tikv.GetTiKVClient func (s *Server) init() error { params := paramtable.Get() etcdConfig := ¶ms.EtcdCfg - rpcParams := ¶ms.RootCoordGrpcServerCfg - log.Debug("init params done..") + log.Info("init params done..") etcdCli, err := etcd.CreateEtcdClient( etcdConfig.UseEmbedEtcd.GetAsBool(), @@ -194,33 +207,33 @@ func (s *Server) init() error { etcdConfig.EtcdTLSCACert.GetValue(), etcdConfig.EtcdTLSMinVersion.GetValue()) if err != nil { - log.Debug("RootCoord connect to etcd failed", zap.Error(err)) + log.Warn("RootCoord connect to etcd failed", zap.Error(err)) return err } s.etcdCli = etcdCli s.rootCoord.SetEtcdClient(s.etcdCli) - s.rootCoord.SetAddress(rpcParams.GetAddress()) - log.Debug("etcd connect done ...") + s.rootCoord.SetAddress(s.listener.Address()) + log.Info("etcd connect done ...") if params.MetaStoreCfg.MetaStoreType.GetValue() == util.MetaStoreTypeTiKV { log.Info("Connecting to tikv metadata storage.") s.tikvCli, err = getTiKVClient(¶mtable.Get().TiKVCfg) if err != nil { - log.Debug("RootCoord failed to connect to tikv", zap.Error(err)) + log.Warn("RootCoord failed to connect to tikv", zap.Error(err)) return err } s.rootCoord.SetTiKVClient(s.tikvCli) log.Info("Connected to tikv. Using tikv as metadata storage.") } - err = s.startGrpc(rpcParams.Port.GetAsInt()) + err = s.startGrpc() if err != nil { return err } - log.Debug("grpc init done ...") + log.Info("grpc init done ...") if s.newDataCoordClient != nil { - log.Debug("RootCoord start to create DataCoord client") + log.Info("RootCoord start to create DataCoord client") dataCoord := s.newDataCoordClient() s.dataCoord = dataCoord if err := s.rootCoord.SetDataCoordClient(dataCoord); err != nil { @@ -229,7 +242,7 @@ func (s *Server) init() error { } if s.newQueryCoordClient != nil { - log.Debug("RootCoord start to create QueryCoord client") + log.Info("RootCoord start to create QueryCoord client") queryCoord := s.newQueryCoordClient() s.queryCoord = queryCoord if err := s.rootCoord.SetQueryCoordClient(queryCoord); err != nil { @@ -240,15 +253,15 @@ func (s *Server) init() error { return s.rootCoord.Init() } -func (s *Server) startGrpc(port int) error { +func (s *Server) startGrpc() error { s.grpcWG.Add(1) - go s.startGrpcLoop(port) + go s.startGrpcLoop() // wait for grpc server loop start err := <-s.grpcErrChan return err } -func (s *Server) startGrpcLoop(port int) { +func (s *Server) startGrpcLoop() { defer s.grpcWG.Done() Params := ¶mtable.Get().RootCoordGrpcServerCfg kaep := keepalive.EnforcementPolicy{ @@ -260,13 +273,7 @@ func (s *Server) startGrpcLoop(port int) { Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } - log.Debug("start grpc ", zap.Int("port", port)) - lis, err := net.Listen("tcp", ":"+strconv.Itoa(port)) - if err != nil { - log.Error("GrpcServer:failed to listen", zap.Error(err)) - s.grpcErrChan <- err - return - } + log.Info("start grpc ", zap.Int("port", s.listener.Port())) ctx, cancel := context.WithCancel(s.ctx) defer cancel() @@ -300,7 +307,7 @@ func (s *Server) startGrpcLoop(port int) { rootcoordpb.RegisterRootCoordServer(s.grpcServer, s) go funcutil.CheckGrpcReady(ctx, s.grpcErrChan) - if err := s.grpcServer.Serve(lis); err != nil { + if err := s.grpcServer.Serve(s.listener); err != nil { s.grpcErrChan <- err } } @@ -321,8 +328,10 @@ func (s *Server) start() error { } func (s *Server) Stop() (err error) { - Params := ¶mtable.Get().RootCoordGrpcServerCfg - logger := log.With(zap.String("address", Params.GetAddress())) + logger := log.With() + if s.listener != nil { + logger = log.With(zap.String("address", s.listener.Address())) + } logger.Info("Rootcoord stopping") defer func() { logger.Info("Rootcoord stopped", zap.Error(err)) @@ -358,6 +367,9 @@ func (s *Server) Stop() (err error) { } s.cancel() + if s.listener != nil { + s.listener.Close() + } return nil } diff --git a/internal/distributed/rootcoord/service_test.go b/internal/distributed/rootcoord/service_test.go index e758dbf2bb165..43c9ba6ba0534 100644 --- a/internal/distributed/rootcoord/service_test.go +++ b/internal/distributed/rootcoord/service_test.go @@ -118,24 +118,27 @@ func TestRun(t *testing.T) { parameters := []string{"tikv", "etcd"} for _, v := range parameters { paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, v) - ctx, cancel := context.WithCancel(context.Background()) + ctx := context.Background() getTiKVClient = func(cfg *paramtable.TiKVConfig) (*txnkv.Client, error) { return tikv.SetupLocalTxn(), nil } defer func() { getTiKVClient = tikv.GetTiKVClient }() - svr := Server{ - rootCoord: &mockCore{}, - ctx: ctx, - cancel: cancel, - grpcErrChan: make(chan error), - } rcServerConfig := ¶mtable.Get().RootCoordGrpcServerCfg + oldPort := rcServerConfig.Port.GetValue() paramtable.Get().Save(rcServerConfig.Port.Key, "1000000") - err := svr.Run() + svr, err := NewServer(ctx, nil) + assert.NoError(t, err) + err = svr.Prepare() assert.Error(t, err) assert.EqualError(t, err, "listen tcp: address 1000000: invalid port") + paramtable.Get().Save(rcServerConfig.Port.Key, oldPort) + + svr, err = NewServer(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, svr) + svr.rootCoord = &mockCore{} mockDataCoord := mocks.NewMockDataCoordClient(t) mockDataCoord.EXPECT().Close().Return(nil) @@ -172,6 +175,8 @@ func TestRun(t *testing.T) { sessKey := path.Join(rootcoord.Params.EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot) _, err = etcdCli.Delete(ctx, sessKey, clientv3.WithPrefix()) assert.NoError(t, err) + err = svr.Prepare() + assert.NoError(t, err) err = svr.Run() assert.NoError(t, err) @@ -236,6 +241,8 @@ func TestServerRun_DataCoordClientInitErr(t *testing.T) { server.newDataCoordClient = func() types.DataCoordClient { return mockDataCoord } + err = server.Prepare() + assert.NoError(t, err) assert.Panics(t, func() { server.Run() }) err = server.Stop() @@ -264,6 +271,8 @@ func TestServerRun_DataCoordClientStartErr(t *testing.T) { server.newDataCoordClient = func() types.DataCoordClient { return mockDataCoord } + err = server.Prepare() + assert.NoError(t, err) assert.Panics(t, func() { server.Run() }) err = server.Stop() @@ -292,7 +301,8 @@ func TestServerRun_QueryCoordClientInitErr(t *testing.T) { server.newQueryCoordClient = func() types.QueryCoordClient { return mockQueryCoord } - + err = server.Prepare() + assert.NoError(t, err) assert.Panics(t, func() { server.Run() }) err = server.Stop() @@ -321,6 +331,8 @@ func TestServer_QueryCoordClientStartErr(t *testing.T) { server.newQueryCoordClient = func() types.QueryCoordClient { return mockQueryCoord } + err = server.Prepare() + assert.NoError(t, err) assert.Panics(t, func() { server.Run() }) err = server.Stop() diff --git a/internal/distributed/streamingnode/service.go b/internal/distributed/streamingnode/service.go index 3d5773e506441..3f17c1aeb4c9e 100644 --- a/internal/distributed/streamingnode/service.go +++ b/internal/distributed/streamingnode/service.go @@ -18,8 +18,6 @@ package streamingnode import ( "context" - "fmt" - "net" "os" "strconv" "sync" @@ -55,8 +53,8 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" + "github.com/milvus-io/milvus/pkg/util/netutil" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/tikv" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -75,7 +73,7 @@ type Server struct { // rpc grpcServer *grpc.Server - lis net.Listener + listener *netutil.NetListener factory dependency.Factory @@ -98,6 +96,23 @@ func NewServer(f dependency.Factory) (*Server, error) { }, nil } +func (s *Server) Prepare() error { + listener, err := netutil.NewListener( + netutil.OptIP(paramtable.Get().StreamingNodeGrpcServerCfg.IP), + netutil.OptHighPriorityToUsePort(paramtable.Get().StreamingNodeGrpcServerCfg.Port.GetAsInt()), + ) + if err != nil { + log.Warn("StreamingNode fail to create net listener", zap.Error(err)) + return err + } + s.listener = listener + log.Info("StreamingNode listen on", zap.String("address", listener.Addr().String()), zap.Int("port", listener.Port())) + paramtable.Get().Save( + paramtable.Get().StreamingNodeGrpcServerCfg.Port.Key, + strconv.FormatInt(int64(listener.Port()), 10)) + return nil +} + // Run runs the server. func (s *Server) Run() error { // TODO: We should set a timeout for the process startup. @@ -126,8 +141,7 @@ func (s *Server) Stop() (err error) { func (s *Server) stop() { s.componentState.OnStopping() - addr, _ := s.getAddress() - log.Info("streamingnode stop", zap.String("Address", addr)) + log.Info("streamingnode stop", zap.String("Address", s.listener.Address())) // Unregister current server from etcd. log.Info("streamingnode unregister session from etcd...") @@ -164,6 +178,10 @@ func (s *Server) stop() { log.Info("wait for grpc server stop...") <-s.grpcServerChan log.Info("streamingnode stop done") + + if err := s.listener.Close(); err != nil { + log.Warn("streamingnode stop listener failed", zap.Error(err)) + } } // Health check the health status of streamingnode. @@ -190,9 +208,6 @@ func (s *Server) init(ctx context.Context) (err error) { if err := s.initChunkManager(ctx); err != nil { return err } - if err := s.allocateAddress(); err != nil { - return err - } if err := s.initSession(ctx); err != nil { return err } @@ -249,13 +264,9 @@ func (s *Server) initSession(ctx context.Context) error { if s.session == nil { return errors.New("session is nil, the etcd client connection may have failed") } - addr, err := s.getAddress() - if err != nil { - return err - } - s.session.Init(typeutil.StreamingNodeRole, addr, false, true) + s.session.Init(typeutil.StreamingNodeRole, s.listener.Address(), false, true) paramtable.SetNodeID(s.session.ServerID) - log.Info("StreamingNode init session", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("node address", addr)) + log.Info("StreamingNode init session", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("node address", s.listener.Address())) return nil } @@ -360,42 +371,13 @@ func (s *Server) initGRPCServer() { streamingpb.RegisterStreamingNodeStateServiceServer(s.grpcServer, s.componentState) } -// allocateAddress allocates a available address for streamingnode grpc server. -func (s *Server) allocateAddress() (err error) { - port := paramtable.Get().StreamingNodeGrpcServerCfg.Port.GetAsInt() - - retry.Do(context.Background(), func() error { - addr := ":" + strconv.Itoa(port) - s.lis, err = net.Listen("tcp", addr) - if err != nil { - if port != 0 { - // set port=0 to get next available port by os - log.Warn("StreamingNode suggested port is in used, try to get by os", zap.Error(err)) - port = 0 - } - } - return err - }, retry.Attempts(10)) - return err -} - -// getAddress returns the address of streamingnode grpc server. -// must be called after allocateAddress. -func (s *Server) getAddress() (string, error) { - if s.lis == nil { - return "", errors.New("StreamingNode grpc server is not initialized") - } - ip := paramtable.Get().StreamingNodeGrpcServerCfg.IP - return fmt.Sprintf("%s:%d", ip, s.lis.Addr().(*net.TCPAddr).Port), nil -} - // startGRPCServer starts the grpc server. func (s *Server) startGPRCServer(ctx context.Context) error { errCh := make(chan error, 1) go func() { defer close(s.grpcServerChan) - if err := s.grpcServer.Serve(s.lis); err != nil { + if err := s.grpcServer.Serve(s.listener); err != nil { select { case errCh <- err: // failure at initial startup. diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 3da16038e5913..c500bff6461f8 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -96,13 +96,16 @@ func runRootCoord(ctx context.Context, localMsg bool) *grpcrootcoord.Server { wg.Add(1) go func() { + defer wg.Done() factory := dependency.NewDefaultFactory(localMsg) var err error rc, err = grpcrootcoord.NewServer(ctx, factory) if err != nil { panic(err) } - wg.Done() + if err = rc.Prepare(); err != nil { + panic(err) + } err = rc.Run() if err != nil { panic(err) @@ -120,13 +123,16 @@ func runQueryCoord(ctx context.Context, localMsg bool) *grpcquerycoord.Server { wg.Add(1) go func() { + defer wg.Done() factory := dependency.NewDefaultFactory(localMsg) var err error qs, err = grpcquerycoord.NewServer(ctx, factory) if err != nil { panic(err) } - wg.Done() + if err = qs.Prepare(); err != nil { + panic(err) + } err = qs.Run() if err != nil { panic(err) @@ -144,13 +150,16 @@ func runQueryNode(ctx context.Context, localMsg bool, alias string) *grpcqueryno wg.Add(1) go func() { + defer wg.Done() factory := dependency.MockDefaultFactory(localMsg, Params) var err error qn, err = grpcquerynode.NewServer(ctx, factory) if err != nil { panic(err) } - wg.Done() + if err = qn.Prepare(); err != nil { + panic(err) + } err = qn.Run() if err != nil { panic(err) @@ -168,10 +177,17 @@ func runDataCoord(ctx context.Context, localMsg bool) *grpcdatacoordclient.Serve wg.Add(1) go func() { + defer wg.Done() factory := dependency.NewDefaultFactory(localMsg) - ds = grpcdatacoordclient.NewServer(ctx, factory) - wg.Done() - err := ds.Run() + var err error + ds, err = grpcdatacoordclient.NewServer(ctx, factory) + if err != nil { + panic(err) + } + if err = ds.Prepare(); err != nil { + panic(err) + } + err = ds.Run() if err != nil { panic(err) } @@ -188,13 +204,16 @@ func runDataNode(ctx context.Context, localMsg bool, alias string) *grpcdatanode wg.Add(1) go func() { + defer wg.Done() factory := dependency.MockDefaultFactory(localMsg, Params) var err error dn, err = grpcdatanode.NewServer(ctx, factory) if err != nil { panic(err) } - wg.Done() + if err = dn.Prepare(); err != nil { + panic(err) + } err = dn.Run() if err != nil { panic(err) @@ -212,13 +231,13 @@ func runIndexNode(ctx context.Context, localMsg bool, alias string) *grpcindexno wg.Add(1) go func() { + defer wg.Done() factory := dependency.MockDefaultFactory(localMsg, Params) var err error in, err = grpcindexnode.NewServer(ctx, factory) if err != nil { panic(err) } - wg.Done() etcd, err := etcd.GetEtcdClient( Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), Params.EtcdCfg.EtcdUseSSL.GetAsBool(), @@ -231,6 +250,9 @@ func runIndexNode(ctx context.Context, localMsg bool, alias string) *grpcindexno panic(err) } in.SetEtcdClient(etcd) + if err = in.Prepare(); err != nil { + panic(err) + } err = in.Run() if err != nil { panic(err) diff --git a/pkg/util/netutil/listener.go b/pkg/util/netutil/listener.go new file mode 100644 index 0000000000000..19b19202efbae --- /dev/null +++ b/pkg/util/netutil/listener.go @@ -0,0 +1,147 @@ +package netutil + +import ( + "crypto/tls" + "fmt" + "net" + + "github.com/milvus-io/milvus/pkg/util/funcutil" +) + +// NewListener creates a new listener that listens on the specified network and IP address. +func NewListener(opts ...Opt) (*NetListener, error) { + config := getNetListenerConfig(opts...) + if config.tlsConfig != nil { + return newTLSListener(config.tlsConfig, opts...) + } + + // Use the highPriorityToUsePort if it is set. + if config.highPriorityToUsePort != 0 { + if lis, err := net.Listen(config.net, fmt.Sprintf(":%d", config.highPriorityToUsePort)); err == nil { + return &NetListener{ + Listener: lis, + port: config.highPriorityToUsePort, + address: fmt.Sprintf("%s:%d", config.ip, config.highPriorityToUsePort), + }, nil + } + } + // Otherwise use the port number specified by the user. + lis, err := net.Listen(config.net, fmt.Sprintf(":%d", config.port)) + if err != nil { + return nil, err + } + return &NetListener{ + Listener: lis, + port: lis.Addr().(*net.TCPAddr).Port, + address: fmt.Sprintf("%s:%d", config.ip, lis.Addr().(*net.TCPAddr).Port), + }, nil +} + +// newTLSListener creates a new listener that listens on the specified network and IP address with TLS. +func newTLSListener(c *tls.Config, opts ...Opt) (*NetListener, error) { + config := getNetListenerConfig(opts...) + // Use the highPriorityToUsePort if it is set. + if config.highPriorityToUsePort != 0 { + if lis, err := tls.Listen(config.net, fmt.Sprintf(":%d", config.highPriorityToUsePort), c); err == nil { + return &NetListener{ + Listener: lis, + port: config.highPriorityToUsePort, + address: fmt.Sprintf("%s:%d", config.ip, config.highPriorityToUsePort), + }, nil + } + } + // Otherwise use the port number specified by the user. + lis, err := tls.Listen(config.net, fmt.Sprintf(":%d", config.port), c) + if err != nil { + return nil, err + } + return &NetListener{ + Listener: lis, + port: lis.Addr().(*net.TCPAddr).Port, + address: fmt.Sprintf("%s:%d", config.ip, lis.Addr().(*net.TCPAddr).Port), + }, nil +} + +// NetListener is a wrapper around a net.Listener that provides additional functionality. +type NetListener struct { + net.Listener + port int + address string +} + +// Port returns the port that the listener is listening on. +func (nl *NetListener) Port() int { + return nl.port +} + +// Address returns the address that the listener is listening on. +func (nl *NetListener) Address() string { + return nl.address +} + +// netListenerConfig contains the configuration for a NetListener. +type netListenerConfig struct { + net string + ip string + highPriorityToUsePort int + port int + tlsConfig *tls.Config +} + +// getNetListenerConfig returns a netListenerConfig with the default values. +func getNetListenerConfig(opts ...Opt) *netListenerConfig { + defaultConfig := &netListenerConfig{ + net: "tcp", + ip: funcutil.GetLocalIP(), + highPriorityToUsePort: 0, + port: 0, + } + for _, opt := range opts { + opt(defaultConfig) + } + return defaultConfig +} + +// Opt is a function that configures a netListenerConfig. +type Opt func(*netListenerConfig) + +// OptNet sets the network type for the listener. +func OptNet(net string) Opt { + return func(nlc *netListenerConfig) { + nlc.net = net + } +} + +// OptIP sets the IP address for the listener. +func OptIP(ip string) Opt { + return func(nlc *netListenerConfig) { + nlc.ip = ip + } +} + +// OptHighPriorityToUsePort sets the port number to use for the listener. +func OptHighPriorityToUsePort(port int) Opt { + return func(nlc *netListenerConfig) { + if nlc.port != 0 { + panic("OptHighPriorityToUsePort and OptPort are mutually exclusive") + } + nlc.highPriorityToUsePort = port + } +} + +// OptPort sets the port number to use for the listener. +func OptPort(port int) Opt { + return func(nlc *netListenerConfig) { + if nlc.highPriorityToUsePort != 0 { + panic("OptHighPriorityToUsePort and OptPort are mutually exclusive") + } + nlc.port = port + } +} + +// OptTLS sets the TLS configuration for the listener. +func OptTLS(c *tls.Config) Opt { + return func(nlc *netListenerConfig) { + nlc.tlsConfig = c + } +} diff --git a/pkg/util/netutil/listener_test.go b/pkg/util/netutil/listener_test.go new file mode 100644 index 0000000000000..f84f7398edaab --- /dev/null +++ b/pkg/util/netutil/listener_test.go @@ -0,0 +1,38 @@ +package netutil + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestListener(t *testing.T) { + l, err := NewListener( + OptIP("127.0.0.1"), + OptPort(0), + ) + assert.NoError(t, err) + assert.NotNil(t, l) + assert.NotZero(t, l.Port()) + assert.Equal(t, l.Address(), fmt.Sprintf("127.0.0.1:%d", l.Port())) + + l2, err := NewListener( + OptIP("127.0.0.1"), + OptPort(l.Port()), + ) + assert.Error(t, err) + assert.Nil(t, l2) + + l3, err := NewListener( + OptIP("127.0.0.1"), + OptHighPriorityToUsePort(l.Port()), + ) + assert.NoError(t, err) + assert.NotNil(t, l3) + assert.NotZero(t, l3.Port()) + assert.Equal(t, l3.Address(), fmt.Sprintf("127.0.0.1:%d", l3.Port())) + + l3.Close() + l.Close() +} diff --git a/tests/integration/coorddownsearch/search_after_coord_down_test.go b/tests/integration/coorddownsearch/search_after_coord_down_test.go index 0a0df039dc34e..1b02bd0a6e894 100644 --- a/tests/integration/coorddownsearch/search_after_coord_down_test.go +++ b/tests/integration/coorddownsearch/search_after_coord_down_test.go @@ -30,9 +30,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - grpcdatacoord "github.com/milvus-io/milvus/internal/distributed/datacoord" - grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord" - grpcrootcoord "github.com/milvus-io/milvus/internal/distributed/rootcoord" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -277,7 +274,6 @@ func (s *CoordDownSearch) setupData() { } func (s *CoordDownSearch) searchAfterCoordDown() float64 { - var err error c := s.Cluster params := paramtable.Get() @@ -285,19 +281,19 @@ func (s *CoordDownSearch) searchAfterCoordDown() float64 { start := time.Now() log.Info("=========================Data Coordinators stopped=========================") - c.DataCoord.Stop() + c.StopDataCoord() s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Eventually) s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Bounded) s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Strong) log.Info("=========================Query Coordinators stopped=========================") - c.QueryCoord.Stop() + c.StopQueryCoord() s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Eventually) s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Bounded) s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Strong) log.Info("=========================Root Coordinators stopped=========================") - c.RootCoord.Stop() + c.StopRootCoord() params.Save(params.CommonCfg.GracefulTime.Key, "60000") s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Bounded) s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Eventually) @@ -307,28 +303,19 @@ func (s *CoordDownSearch) searchAfterCoordDown() float64 { log.Info(fmt.Sprintf("=========================Failed search cost: %fs=========================", time.Since(failedStart).Seconds())) log.Info("=========================restart Root Coordinators=========================") - c.RootCoord, err = grpcrootcoord.NewServer(context.TODO(), c.GetFactory()) - s.NoError(err) - err = c.RootCoord.Run() - s.NoError(err) + c.StartRootCoord() s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Eventually) s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Bounded) s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Strong) log.Info("=========================restart Data Coordinators=========================") - c.DataCoord = grpcdatacoord.NewServer(context.TODO(), c.GetFactory()) - s.NoError(err) - err = c.DataCoord.Run() - s.NoError(err) + c.StartDataCoord() s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Eventually) s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Bounded) s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Strong) log.Info("=========================restart Query Coordinators=========================") - c.QueryCoord, err = grpcquerycoord.NewServer(context.TODO(), c.GetFactory()) - s.NoError(err) - err = c.QueryCoord.Run() - s.NoError(err) + c.StartQueryCoord() s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Eventually) s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Bounded) s.search(searchCollectionName, Dim, commonpb.ConsistencyLevel_Strong) diff --git a/tests/integration/coordrecovery/coord_recovery_test.go b/tests/integration/coordrecovery/coord_recovery_test.go index 2ca2a455fc33f..46f02c593bbe6 100644 --- a/tests/integration/coordrecovery/coord_recovery_test.go +++ b/tests/integration/coordrecovery/coord_recovery_test.go @@ -30,9 +30,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - grpcdatacoord "github.com/milvus-io/milvus/internal/distributed/datacoord" - grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord" - grpcrootcoord "github.com/milvus-io/milvus/internal/distributed/rootcoord" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -236,31 +233,20 @@ func (s *CoordSwitchSuite) setupData() { } func (s *CoordSwitchSuite) switchCoord() float64 { - var err error c := s.Cluster start := time.Now() log.Info("=========================Stopping Coordinators========================") - c.RootCoord.Stop() - c.DataCoord.Stop() - c.QueryCoord.Stop() + c.StopRootCoord() + c.StopDataCoord() + c.StopQueryCoord() log.Info("=========================Coordinators stopped=========================", zap.Duration("elapsed", time.Since(start))) start = time.Now() - c.RootCoord, err = grpcrootcoord.NewServer(context.TODO(), c.GetFactory()) - s.NoError(err) - c.DataCoord = grpcdatacoord.NewServer(context.TODO(), c.GetFactory()) - c.QueryCoord, err = grpcquerycoord.NewServer(context.TODO(), c.GetFactory()) - s.NoError(err) - log.Info("=========================Coordinators recreated=========================") - - err = c.RootCoord.Run() - s.NoError(err) + c.StartRootCoord() log.Info("=========================RootCoord restarted=========================") - err = c.DataCoord.Run() - s.NoError(err) + c.StartDataCoord() log.Info("=========================DataCoord restarted=========================") - err = c.QueryCoord.Run() - s.NoError(err) + c.StartQueryCoord() log.Info("=========================QueryCoord restarted=========================") for i := 0; i < 1000; i++ { diff --git a/tests/integration/datanode/compaction_test.go b/tests/integration/datanode/compaction_test.go index 8cd7e77c6a8c5..b710977a0afe9 100644 --- a/tests/integration/datanode/compaction_test.go +++ b/tests/integration/datanode/compaction_test.go @@ -14,7 +14,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - grpcdatacoord "github.com/milvus-io/milvus/internal/distributed/datacoord" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" @@ -152,10 +151,8 @@ func (s *CompactionSuite) compactAndReboot(collection string) { // Reboot if planResp.GetMergeInfos()[0].GetTarget() == int64(-1) { - s.Cluster.DataCoord.Stop() - s.Cluster.DataCoord = grpcdatacoord.NewServer(ctx, s.Cluster.GetFactory()) - err = s.Cluster.DataCoord.Run() - s.Require().NoError(err) + s.Cluster.StopDataCoord() + s.Cluster.StartDataCoord() stateResp, err = s.Cluster.Proxy.GetCompactionState(ctx, &milvuspb.GetCompactionStateRequest{ CompactionID: compactID, diff --git a/tests/integration/minicluster_v2.go b/tests/integration/minicluster_v2.go index 37cf1459b48f8..762b2dc28b4e4 100644 --- a/tests/integration/minicluster_v2.go +++ b/tests/integration/minicluster_v2.go @@ -232,7 +232,10 @@ func StartMiniClusterV2(ctx context.Context, opts ...OptionV2) (*MiniClusterV2, if err != nil { return nil, err } - cluster.DataCoord = grpcdatacoord.NewServer(ctx, cluster.factory) + cluster.DataCoord, err = grpcdatacoord.NewServer(ctx, cluster.factory) + if err != nil { + return nil, err + } cluster.QueryCoord, err = grpcquerycoord.NewServer(ctx, cluster.factory) if err != nil { return nil, err @@ -282,10 +285,7 @@ func (cluster *MiniClusterV2) AddQueryNode() *grpcquerynode.Server { if err != nil { return nil } - err = node.Run() - if err != nil { - return nil - } + runComponent(node) paramtable.SetNodeID(oid) req := &milvuspb.GetComponentStatesRequest{} @@ -310,10 +310,7 @@ func (cluster *MiniClusterV2) AddDataNode() *grpcdatanode.Server { if err != nil { return nil } - err = node.Run() - if err != nil { - return nil - } + runComponent(node) paramtable.SetNodeID(oid) req := &milvuspb.GetComponentStatesRequest{} @@ -334,50 +331,19 @@ func (cluster *MiniClusterV2) AddStreamingNode() { if err != nil { panic(err) } - err = node.Run() - if err != nil { - panic(err) - } - + runComponent(node) cluster.streamingnodes = append(cluster.streamingnodes, node) } func (cluster *MiniClusterV2) Start() error { log.Info("mini cluster start") - err := cluster.RootCoord.Run() - if err != nil { - return err - } - - err = cluster.DataCoord.Run() - if err != nil { - return err - } - - err = cluster.QueryCoord.Run() - if err != nil { - return err - } - - err = cluster.DataNode.Run() - if err != nil { - return err - } - - err = cluster.QueryNode.Run() - if err != nil { - return err - } - - err = cluster.IndexNode.Run() - if err != nil { - return err - } - - err = cluster.Proxy.Run() - if err != nil { - return err - } + runComponent(cluster.RootCoord) + runComponent(cluster.DataCoord) + runComponent(cluster.QueryCoord) + runComponent(cluster.Proxy) + runComponent(cluster.DataNode) + runComponent(cluster.QueryNode) + runComponent(cluster.IndexNode) ctx2, cancel := context.WithTimeout(context.Background(), time.Second*120) defer cancel() @@ -392,13 +358,11 @@ func (cluster *MiniClusterV2) Start() error { } if streamingutil.IsStreamingServiceEnabled() { - err = cluster.StreamingNode.Run() - if err != nil { - return err - } + runComponent(cluster.StreamingNode) } port := params.ProxyGrpcServerCfg.Port.GetAsInt() + var err error cluster.clientConn, err = grpc.DialContext(cluster.ctx, fmt.Sprintf("localhost:%d", port), getGrpcDialOpt()...) if err != nil { return err @@ -409,6 +373,57 @@ func (cluster *MiniClusterV2) Start() error { return nil } +func (cluster *MiniClusterV2) StopRootCoord() { + if err := cluster.RootCoord.Stop(); err != nil { + panic(err) + } + cluster.RootCoord = nil +} + +func (cluster *MiniClusterV2) StartRootCoord() { + if cluster.RootCoord == nil { + var err error + if cluster.RootCoord, err = grpcrootcoord.NewServer(cluster.ctx, cluster.factory); err != nil { + panic(err) + } + runComponent(cluster.RootCoord) + } +} + +func (cluster *MiniClusterV2) StopDataCoord() { + if err := cluster.DataCoord.Stop(); err != nil { + panic(err) + } + cluster.DataCoord = nil +} + +func (cluster *MiniClusterV2) StartDataCoord() { + if cluster.DataCoord == nil { + var err error + if cluster.DataCoord, err = grpcdatacoord.NewServer(cluster.ctx, cluster.factory); err != nil { + panic(err) + } + runComponent(cluster.DataCoord) + } +} + +func (cluster *MiniClusterV2) StopQueryCoord() { + if err := cluster.QueryCoord.Stop(); err != nil { + panic(err) + } + cluster.QueryCoord = nil +} + +func (cluster *MiniClusterV2) StartQueryCoord() { + if cluster.QueryCoord == nil { + var err error + if cluster.QueryCoord, err = grpcquerycoord.NewServer(cluster.ctx, cluster.factory); err != nil { + panic(err) + } + runComponent(cluster.QueryCoord) + } +} + func getGrpcDialOpt() []grpc.DialOption { return []grpc.DialOption{ grpc.WithBlock(), @@ -581,3 +596,17 @@ func (r *ReportChanExtension) Report(info any) int { func (r *ReportChanExtension) GetReportChan() <-chan any { return r.reportChan } + +type component interface { + Prepare() error + Run() error +} + +func runComponent(c component) { + if err := c.Prepare(); err != nil { + panic(err) + } + if err := c.Run(); err != nil { + panic(err) + } +} diff --git a/tests/integration/target/target_test.go b/tests/integration/target/target_test.go index f2c2f22446812..3434a6f2c2084 100644 --- a/tests/integration/target/target_test.go +++ b/tests/integration/target/target_test.go @@ -30,7 +30,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" @@ -158,7 +157,7 @@ func (s *TargetTestSuit) TestQueryCoordRestart() { collectionID := info.GetCollectionID() // trigger old coord stop - s.Cluster.QueryCoord.Stop() + s.Cluster.StopQueryCoord() // keep insert, make segment list change every 3 seconds closeInsertCh := make(chan struct{}) @@ -186,17 +185,11 @@ func (s *TargetTestSuit) TestQueryCoordRestart() { paramtable.Get().Save(paramtable.Get().QueryCoordGrpcServerCfg.Port.Key, fmt.Sprint(port)) // start a new QC - newQC, err := grpcquerycoord.NewServer(ctx, s.Cluster.GetFactory()) - s.NoError(err) - go func() { - err := newQC.Run() - s.NoError(err) - }() - s.Cluster.QueryCoord = newQC + s.Cluster.StartQueryCoord() // after new QC become Active, expected the new target is ready immediately, and get shard leader success s.Eventually(func() bool { - resp, err := newQC.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) + resp, err := s.Cluster.QueryCoord.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) s.NoError(err) if resp.IsHealthy { resp, err := s.Cluster.QueryCoord.GetShardLeaders(ctx, &querypb.GetShardLeadersRequest{