diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index 25e61dd12d38e..449e55538399f 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -44,6 +44,7 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/grpcclient" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -236,7 +237,7 @@ func defaultIndexNodeCreatorFunc(ctx context.Context, addr string, nodeID int64) } func defaultRootCoordCreatorFunc(ctx context.Context, metaRootPath string, client *clientv3.Client) (types.RootCoordClient, error) { - return rootcoordclient.NewClient(ctx, metaRootPath, client) + return rootcoordclient.NewClient(ctx, grpcclient.NewRawEntryProvider(client, metaRootPath, typeutil.RootCoordRole)) } // QuitSignal returns signal when server quits diff --git a/internal/distributed/datacoord/client/client.go b/internal/distributed/datacoord/client/client.go index 97c3bfd908a13..99cf60b2f2bb7 100644 --- a/internal/distributed/datacoord/client/client.go +++ b/internal/distributed/datacoord/client/client.go @@ -18,9 +18,7 @@ package grpcdatacoordclient import ( "context" - "fmt" - clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" "google.golang.org/grpc" @@ -31,7 +29,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/grpcclient" - "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -46,28 +43,20 @@ var _ types.DataCoordClient = (*Client)(nil) // Client is the datacoord grpc client type Client struct { grpcClient grpcclient.GrpcClient[datapb.DataCoordClient] - 
sess *sessionutil.Session + sp grpcclient.ServiceProvider sourceID int64 } // NewClient creates a new client instance -func NewClient(ctx context.Context, metaRoot string, etcdCli *clientv3.Client) (*Client, error) { - sess := sessionutil.NewSession(ctx, metaRoot, etcdCli) - if sess == nil { - err := fmt.Errorf("new session error, maybe can not connect to etcd") - log.Debug("DataCoordClient NewClient failed", zap.Error(err)) - return nil, err - } - +func NewClient(ctx context.Context, sp grpcclient.ServiceProvider) (*Client, error) { config := &Params.DataCoordGrpcClientCfg client := &Client{ grpcClient: grpcclient.NewClientBase[datapb.DataCoordClient](config, "milvus.proto.data.DataCoord"), - sess: sess, + sp: sp, } client.grpcClient.SetRole(typeutil.DataCoordRole) client.grpcClient.SetGetAddrFunc(client.getDataCoordAddr) client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) - client.grpcClient.SetSession(sess) return client, nil } @@ -76,21 +65,14 @@ func (c *Client) newGrpcClient(cc *grpc.ClientConn) datapb.DataCoordClient { return datapb.NewDataCoordClient(cc) } -func (c *Client) getDataCoordAddr() (string, error) { - key := c.grpcClient.GetRole() - msess, _, err := c.sess.GetSessions(key) +func (c *Client) getDataCoordAddr(ctx context.Context) (string, error) { + addr, serverID, err := c.sp.GetServiceEntry(ctx) if err != nil { - log.Debug("DataCoordClient, getSessions failed", zap.Any("key", key), zap.Error(err)) + log.Warn("DataCoordClient get service entry failed", zap.Error(err)) return "", err } - ms, ok := msess[key] - if !ok { - log.Debug("DataCoordClient, not existed in msess ", zap.Any("key", key), zap.Any("len of msess", len(msess))) - return "", fmt.Errorf("find no available datacoord, check datacoord state") - } - - c.grpcClient.SetNodeID(ms.ServerID) - return ms.Address, nil + c.grpcClient.SetNodeID(serverID) + return addr, nil } // Stop stops the client @@ -329,7 +311,7 @@ func (c *Client) GetRecoveryInfoV2(ctx context.Context, req 
*datapb.GetRecoveryI req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), - commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.sess.ServerID)), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), ) return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*datapb.GetRecoveryInfoResponseV2, error) { return client.GetRecoveryInfoV2(ctx, req) diff --git a/internal/distributed/datacoord/client/client_test.go b/internal/distributed/datacoord/client/client_test.go index ab4b78590a862..061c4c68968ef 100644 --- a/internal/distributed/datacoord/client/client_test.go +++ b/internal/distributed/datacoord/client/client_test.go @@ -25,16 +25,18 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/internal/util/mock" "go.uber.org/zap" - "google.golang.org/grpc" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proxy" - "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" ) func TestMain(m *testing.M) { @@ -66,7 +68,7 @@ func Test_NewClient(t *testing.T) { Params.EtcdCfg.EtcdTLSCACert.GetValue(), Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) assert.NoError(t, err) - client, err := NewClient(ctx, proxy.Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli) + client, err := NewClient(ctx, grpcclient.NewRawEntryProvider(etcdCli, proxy.Params.EtcdCfg.MetaRootPath.GetValue(), typeutil.DataCoordRole)) assert.NoError(t, err) assert.NotNil(t, client) diff --git a/internal/distributed/datanode/client/client.go b/internal/distributed/datanode/client/client.go index 
9db4d3a7658ef..838610d7fc5d8 100644 --- a/internal/distributed/datanode/client/client.go +++ b/internal/distributed/datanode/client/client.go @@ -70,7 +70,7 @@ func (c *Client) newGrpcClient(cc *grpc.ClientConn) datapb.DataNodeClient { return datapb.NewDataNodeClient(cc) } -func (c *Client) getAddr() (string, error) { +func (c *Client) getAddr(_ context.Context) (string, error) { return c.addr, nil } diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index 3d3dba2537d70..8f2b0da8612d3 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -24,6 +24,12 @@ import ( "sync" "time" + "github.com/milvus-io/milvus/internal/util/componentutil" + "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util/interceptor" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" clientv3 "go.etcd.io/etcd/client/v3" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" @@ -41,17 +47,14 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" - "github.com/milvus-io/milvus/internal/util/componentutil" - "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/retry" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type Server struct { @@ -75,17 +78,17 @@ type Server struct { // NewServer new DataNode grpc server func 
NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) { - ctx1, cancel := context.WithCancel(ctx) - s := &Server{ - ctx: ctx1, + ctx, cancel := context.WithCancel(ctx) + var s = &Server{ + ctx: ctx, cancel: cancel, factory: factory, grpcErrChan: make(chan error), newRootCoordClient: func(etcdMetaRoot string, client *clientv3.Client) (types.RootCoordClient, error) { - return rcc.NewClient(ctx1, etcdMetaRoot, client) + return rcc.NewClient(ctx, grpcclient.NewRawEntryProvider(client, etcdMetaRoot, typeutil.RootCoordRole)) }, newDataCoordClient: func(etcdMetaRoot string, client *clientv3.Client) (types.DataCoordClient, error) { - return dcc.NewClient(ctx1, etcdMetaRoot, client) + return dcc.NewClient(ctx, grpcclient.NewRawEntryProvider(client, etcdMetaRoot, typeutil.DataCoordRole)) }, } diff --git a/internal/distributed/indexnode/client/client.go b/internal/distributed/indexnode/client/client.go index af45015b7f597..e5150768b30be 100644 --- a/internal/distributed/indexnode/client/client.go +++ b/internal/distributed/indexnode/client/client.go @@ -71,7 +71,7 @@ func (c *Client) newGrpcClient(cc *grpc.ClientConn) indexpb.IndexNodeClient { return indexpb.NewIndexNodeClient(cc) } -func (c *Client) getAddr() (string, error) { +func (c *Client) getAddr(_ context.Context) (string, error) { return c.addr, nil } diff --git a/internal/distributed/proxy/client/client.go b/internal/distributed/proxy/client/client.go index b45f00efd7efc..b856f7102fccc 100644 --- a/internal/distributed/proxy/client/client.go +++ b/internal/distributed/proxy/client/client.go @@ -63,7 +63,7 @@ func (c *Client) newGrpcClient(cc *grpc.ClientConn) proxypb.ProxyClient { return proxypb.NewProxyClient(cc) } -func (c *Client) getAddr() (string, error) { +func (c *Client) getAddr(_ context.Context) (string, error) { return c.addr, nil } diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index cd9b928d733a6..ca9521cedb13f 100644 --- 
a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -30,17 +30,30 @@ import ( "sync" "time" + "github.com/milvus-io/milvus/pkg/util/merr" + + "google.golang.org/grpc/credentials" + + management "github.com/milvus-io/milvus/internal/http" + "github.com/milvus-io/milvus/internal/proxy/accesslog" + "github.com/milvus-io/milvus/internal/util/componentutil" + "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util/interceptor" + "github.com/milvus-io/milvus/pkg/util/metricsinfo" + "github.com/milvus-io/milvus/pkg/util/typeutil" + "github.com/soheilhy/cmux" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "github.com/gin-gonic/gin" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" - "github.com/soheilhy/cmux" clientv3 "go.etcd.io/etcd/client/v3" - "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.uber.org/atomic" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" @@ -53,22 +66,14 @@ import ( qcc "github.com/milvus-io/milvus/internal/distributed/querycoord/client" rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" "github.com/milvus-io/milvus/internal/distributed/utils" - management "github.com/milvus-io/milvus/internal/http" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proxy" - "github.com/milvus-io/milvus/internal/proxy/accesslog" "github.com/milvus-io/milvus/internal/types" - "github.com/milvus-io/milvus/internal/util/componentutil" - 
"github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -509,7 +514,7 @@ func (s *Server) init() error { if s.rootCoordClient == nil { var err error log.Debug("create RootCoord client for Proxy") - s.rootCoordClient, err = rcc.NewClient(s.ctx, proxy.Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli) + s.rootCoordClient, err = rcc.NewClient(s.ctx, grpcclient.NewRawEntryProvider(etcdCli, proxy.Params.EtcdCfg.MetaRootPath.GetValue(), typeutil.RootCoordRole)) if err != nil { log.Warn("failed to create RootCoord client for Proxy", zap.Error(err)) return err @@ -531,7 +536,7 @@ func (s *Server) init() error { if s.dataCoordClient == nil { var err error log.Debug("create DataCoord client for Proxy") - s.dataCoordClient, err = dcc.NewClient(s.ctx, proxy.Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli) + s.dataCoordClient, err = dcc.NewClient(s.ctx, grpcclient.NewRawEntryProvider(etcdCli, proxy.Params.EtcdCfg.MetaRootPath.GetValue(), typeutil.DataCoordRole)) if err != nil { log.Warn("failed to create DataCoord client for Proxy", zap.Error(err)) return err @@ -553,7 +558,7 @@ func (s *Server) init() error { if s.queryCoordClient == nil { var err error log.Debug("create QueryCoord client for Proxy") - s.queryCoordClient, err = qcc.NewClient(s.ctx, proxy.Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli) + s.queryCoordClient, err = qcc.NewClient(s.ctx, grpcclient.NewRawEntryProvider(etcdCli, proxy.Params.EtcdCfg.MetaRootPath.GetValue(), typeutil.QueryCoordRole)) if err != nil { log.Warn("failed to create QueryCoord client for Proxy", zap.Error(err)) return err 
diff --git a/internal/distributed/querycoord/client/client.go b/internal/distributed/querycoord/client/client.go index 22eedf074a068..385474ae423dd 100644 --- a/internal/distributed/querycoord/client/client.go +++ b/internal/distributed/querycoord/client/client.go @@ -18,9 +18,7 @@ package grpcquerycoordclient import ( "context" - "fmt" - clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" "google.golang.org/grpc" @@ -29,7 +27,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/util/grpcclient" - "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -42,44 +39,36 @@ var Params *paramtable.ComponentParam = paramtable.Get() // Client is the grpc client of QueryCoord. type Client struct { grpcClient grpcclient.GrpcClient[querypb.QueryCoordClient] - sess *sessionutil.Session + sp grpcclient.ServiceProvider } // NewClient creates a client for QueryCoord grpc call. 
-func NewClient(ctx context.Context, metaRoot string, etcdCli *clientv3.Client) (*Client, error) { - sess := sessionutil.NewSession(ctx, metaRoot, etcdCli) - if sess == nil { - err := fmt.Errorf("new session error, maybe can not connect to etcd") - log.Debug("QueryCoordClient NewClient failed", zap.Error(err)) - return nil, err - } +func NewClient(ctx context.Context, sp grpcclient.ServiceProvider) (*Client, error) { config := &Params.QueryCoordGrpcClientCfg client := &Client{ grpcClient: grpcclient.NewClientBase[querypb.QueryCoordClient](config, "milvus.proto.query.QueryCoord"), - sess: sess, + sp: sp, } client.grpcClient.SetRole(typeutil.QueryCoordRole) client.grpcClient.SetGetAddrFunc(client.getQueryCoordAddr) client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) - client.grpcClient.SetSession(sess) return client, nil } -func (c *Client) getQueryCoordAddr() (string, error) { - key := c.grpcClient.GetRole() - msess, _, err := c.sess.GetSessions(key) +// Init initializes QueryCoord's grpc client. 
+func (c *Client) Init() error { + return nil +} + +func (c *Client) getQueryCoordAddr(ctx context.Context) (string, error) { + addr, serverID, err := c.sp.GetServiceEntry(ctx) if err != nil { - log.Debug("QueryCoordClient GetSessions failed", zap.Error(err)) + log.Warn("QueryCoordClient get service entry failed", zap.Error(err)) return "", err } - ms, ok := msess[key] - if !ok { - log.Debug("QueryCoordClient msess key not existed", zap.Any("key", key)) - return "", fmt.Errorf("find no available querycoord, check querycoord state") - } - c.grpcClient.SetNodeID(ms.ServerID) - return ms.Address, nil + c.grpcClient.SetNodeID(serverID) + return addr, nil } func (c *Client) newGrpcClient(cc *grpc.ClientConn) querypb.QueryCoordClient { @@ -202,7 +191,7 @@ func (c *Client) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncN req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), - commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.sess.ServerID)), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), ) return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) { return client.SyncNewCreatedPartition(ctx, req) diff --git a/internal/distributed/querycoord/client/client_test.go b/internal/distributed/querycoord/client/client_test.go index 248d19d8c561b..1ff4f3c0527a7 100644 --- a/internal/distributed/querycoord/client/client_test.go +++ b/internal/distributed/querycoord/client/client_test.go @@ -25,16 +25,18 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/internal/util/mock" "go.uber.org/zap" - "google.golang.org/grpc" "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/milvus-io/milvus/internal/proxy" - "github.com/milvus-io/milvus/internal/util/mock" 
"github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/internal/proxy" + "github.com/milvus-io/milvus/pkg/util/etcd" ) func TestMain(m *testing.M) { @@ -67,7 +69,7 @@ func Test_NewClient(t *testing.T) { Params.EtcdCfg.EtcdTLSCACert.GetValue(), Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) assert.NoError(t, err) - client, err := NewClient(ctx, proxy.Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli) + client, err := NewClient(ctx, grpcclient.NewRawEntryProvider(etcdCli, proxy.Params.EtcdCfg.MetaRootPath.GetValue(), typeutil.QueryCoordRole)) assert.NoError(t, err) assert.NotNil(t, client) diff --git a/internal/distributed/querycoord/service.go b/internal/distributed/querycoord/service.go index 4281358f12312..df884d4681c56 100644 --- a/internal/distributed/querycoord/service.go +++ b/internal/distributed/querycoord/service.go @@ -24,6 +24,10 @@ import ( "time" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + "github.com/milvus-io/milvus/internal/util/componentutil" + "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/pkg/tracer" "github.com/tikv/client-go/v2/txnkv" clientv3 "go.etcd.io/etcd/client/v3" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" @@ -41,10 +45,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" qc "github.com/milvus-io/milvus/internal/querycoordv2" "github.com/milvus-io/milvus/internal/types" - "github.com/milvus-io/milvus/internal/util/componentutil" - "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" 
@@ -52,6 +53,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/tikv" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) // Server is the grpc server of QueryCoord. @@ -153,7 +155,7 @@ func (s *Server) init() error { // --- Master Server Client --- if s.rootCoord == nil { - s.rootCoord, err = rcc.NewClient(s.loopCtx, qc.Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli) + s.rootCoord, err = rcc.NewClient(s.loopCtx, grpcclient.NewRawEntryProvider(s.etcdCli, qc.Params.EtcdCfg.MetaRootPath.GetValue(), typeutil.RootCoordRole)) if err != nil { log.Error("QueryCoord try to new RootCoord client failed", zap.Error(err)) panic(err) @@ -175,7 +177,7 @@ func (s *Server) init() error { // --- Data service client --- if s.dataCoord == nil { - s.dataCoord, err = dcc.NewClient(s.loopCtx, qc.Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli) + s.dataCoord, err = dcc.NewClient(s.loopCtx, grpcclient.NewRawEntryProvider(s.etcdCli, qc.Params.EtcdCfg.MetaRootPath.GetValue(), typeutil.DataCoordRole)) if err != nil { log.Error("QueryCoord try to new DataCoord client failed", zap.Error(err)) panic(err) diff --git a/internal/distributed/querynode/client/client.go b/internal/distributed/querynode/client/client.go index a7291a9ad50f7..021ad7c45e783 100644 --- a/internal/distributed/querynode/client/client.go +++ b/internal/distributed/querynode/client/client.go @@ -67,7 +67,7 @@ func (c *Client) newGrpcClient(cc *grpc.ClientConn) querypb.QueryNodeClient { return querypb.NewQueryNodeClient(cc) } -func (c *Client) getAddr() (string, error) { +func (c *Client) getAddr(_ context.Context) (string, error) { return c.addr, nil } diff --git a/internal/distributed/rootcoord/client/client.go b/internal/distributed/rootcoord/client/client.go index 1f07f904bd631..09ce09724ddf1 100644 --- a/internal/distributed/rootcoord/client/client.go +++ b/internal/distributed/rootcoord/client/client.go @@ -18,9 +18,7 
@@ package grpcrootcoordclient import ( "context" - "fmt" - clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" "google.golang.org/grpc" grpcCodes "google.golang.org/grpc/codes" @@ -32,7 +30,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/util/grpcclient" - "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -45,30 +42,22 @@ var Params *paramtable.ComponentParam = paramtable.Get() // Client grpc client type Client struct { grpcClient grpcclient.GrpcClient[rootcoordpb.RootCoordClient] - sess *sessionutil.Session + sp grpcclient.ServiceProvider } // NewClient create root coordinator client with specified etcd info and timeout // ctx execution control context -// metaRoot is the path in etcd for root coordinator registration -// etcdEndpoints are the address list for etcd end points +// sp is the service entry information provider // timeout is default setting for each grpc call -func NewClient(ctx context.Context, metaRoot string, etcdCli *clientv3.Client) (*Client, error) { - sess := sessionutil.NewSession(ctx, metaRoot, etcdCli) - if sess == nil { - err := fmt.Errorf("new session error, maybe can not connect to etcd") - log.Debug("QueryCoordClient NewClient failed", zap.Error(err)) - return nil, err - } +func NewClient(ctx context.Context, sp grpcclient.ServiceProvider) (*Client, error) { config := &Params.RootCoordGrpcClientCfg client := &Client{ grpcClient: grpcclient.NewClientBase[rootcoordpb.RootCoordClient](config, "milvus.proto.rootcoord.RootCoord"), - sess: sess, + sp: sp, } client.grpcClient.SetRole(typeutil.RootCoordRole) client.grpcClient.SetGetAddrFunc(client.getRootCoordAddr) client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) - client.grpcClient.SetSession(sess) return client, nil } @@ 
-78,24 +67,14 @@ func (c *Client) newGrpcClient(cc *grpc.ClientConn) rootcoordpb.RootCoordClient return rootcoordpb.NewRootCoordClient(cc) } -func (c *Client) getRootCoordAddr() (string, error) { - key := c.grpcClient.GetRole() - msess, _, err := c.sess.GetSessions(key) +func (c *Client) getRootCoordAddr(ctx context.Context) (string, error) { + addr, serverID, err := c.sp.GetServiceEntry(ctx) if err != nil { - log.Debug("RootCoordClient GetSessions failed", zap.Any("key", key)) + log.Warn("RootCoordClient get service entry failed", zap.Error(err)) return "", err } - ms, ok := msess[key] - if !ok { - log.Warn("RootCoordClient mess key not exist", zap.Any("key", key)) - return "", fmt.Errorf("find no available rootcoord, check rootcoord state") - } - log.Debug("RootCoordClient GetSessions success", - zap.String("address", ms.Address), - zap.Int64("serverID", ms.ServerID), - ) - c.grpcClient.SetNodeID(ms.ServerID) - return ms.Address, nil + c.grpcClient.SetNodeID(serverID) + return addr, nil } // Close terminate grpc connection @@ -600,7 +579,7 @@ func (c *Client) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabase in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), - commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.sess.ServerID)), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), ) ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { @@ -619,7 +598,7 @@ func (c *Client) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequ in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), - commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.sess.ServerID)), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), ) ret, err := c.grpcClient.ReCall(ctx, func(client 
rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { @@ -638,7 +617,7 @@ func (c *Client) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRe in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), - commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.sess.ServerID)), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), ) ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { diff --git a/internal/distributed/rootcoord/client/client_test.go b/internal/distributed/rootcoord/client/client_test.go index ca5aa24ca4613..368ae5e11a084 100644 --- a/internal/distributed/rootcoord/client/client_test.go +++ b/internal/distributed/rootcoord/client/client_test.go @@ -25,16 +25,19 @@ import ( "time" "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/internal/util/mock" "github.com/stretchr/testify/assert" "go.uber.org/zap" - "google.golang.org/grpc" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" - "github.com/milvus-io/milvus/internal/proxy" - "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/internal/proxy" + "github.com/milvus-io/milvus/pkg/util/etcd" ) func TestMain(m *testing.M) { @@ -66,7 +69,7 @@ func Test_NewClient(t *testing.T) { Params.EtcdCfg.EtcdTLSCACert.GetValue(), Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) assert.NoError(t, err) - client, err := NewClient(ctx, proxy.Params.EtcdCfg.MetaRootPath.GetValue(), etcdCli) + client, err := NewClient(ctx, grpcclient.NewRawEntryProvider(etcdCli, proxy.Params.EtcdCfg.MetaRootPath.GetValue(), 
typeutil.RootCoordRole)) assert.NoError(t, err) assert.NotNil(t, client) diff --git a/internal/distributed/rootcoord/service.go b/internal/distributed/rootcoord/service.go index e7be73f655c70..bba0c5470808f 100644 --- a/internal/distributed/rootcoord/service.go +++ b/internal/distributed/rootcoord/service.go @@ -24,6 +24,12 @@ import ( "time" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util" + "github.com/milvus-io/milvus/pkg/util/interceptor" + "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/tikv/client-go/v2/txnkv" clientv3 "go.etcd.io/etcd/client/v3" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" @@ -42,16 +48,12 @@ import ( "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/rootcoord" "github.com/milvus-io/milvus/internal/types" - "github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/tracer" - "github.com/milvus-io/milvus/pkg/util" + "github.com/milvus-io/milvus/internal/util/grpcclient" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/interceptor" - "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/tikv" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) // Server grpc wrapper @@ -126,7 +128,7 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) func (s *Server) setClient() { s.newDataCoordClient = func(etcdMetaRoot string, etcdCli *clientv3.Client) types.DataCoordClient { - dsClient, err := dcc.NewClient(s.ctx, etcdMetaRoot, etcdCli) + dsClient, err := dcc.NewClient(s.ctx, grpcclient.NewRawEntryProvider(etcdCli, 
etcdMetaRoot, typeutil.DataCoordRole)) if err != nil { panic(err) } @@ -134,7 +136,7 @@ func (s *Server) setClient() { } s.newQueryCoordClient = func(metaRootPath string, etcdCli *clientv3.Client) types.QueryCoordClient { - qsClient, err := qcc.NewClient(s.ctx, metaRootPath, etcdCli) + qsClient, err := qcc.NewClient(s.ctx, grpcclient.NewRawEntryProvider(etcdCli, metaRootPath, typeutil.QueryCoordRole)) if err != nil { panic(err) } diff --git a/internal/proto/planpb/plan.pb.go b/internal/proto/planpb/plan.pb.go index 06de0040e91da..2edfb4c92706d 100644 --- a/internal/proto/planpb/plan.pb.go +++ b/internal/proto/planpb/plan.pb.go @@ -237,6 +237,7 @@ func (BinaryExpr_BinaryOp) EnumDescriptor() ([]byte, []int) { type GenericValue struct { // Types that are valid to be assigned to Val: + // // *GenericValue_BoolVal // *GenericValue_Int64Val // *GenericValue_FloatVal @@ -1297,6 +1298,7 @@ var xxx_messageInfo_AlwaysTrueExpr proto.InternalMessageInfo type Expr struct { // Types that are valid to be assigned to Expr: + // // *Expr_TermExpr // *Expr_UnaryExpr // *Expr_BinaryExpr @@ -1668,6 +1670,7 @@ func (m *QueryPlanNode) GetLimit() int64 { type PlanNode struct { // Types that are valid to be assigned to Node: + // // *PlanNode_VectorAnns // *PlanNode_Predicates // *PlanNode_Query diff --git a/internal/proto/proxypb/proxy.pb.go b/internal/proto/proxypb/proxy.pb.go index 60ba4e6a11df1..c5cf24fd78b7e 100644 --- a/internal/proto/proxypb/proxy.pb.go +++ b/internal/proto/proxypb/proxy.pb.go @@ -29,8 +29,9 @@ const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package type InvalidateCollMetaCacheRequest struct { // MsgType: - // DropCollection -> {meta cache, dml channels} - // Other -> {meta cache} + // + // DropCollection -> {meta cache, dml channels} + // Other -> {meta cache} Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` DbName string `protobuf:"bytes,2,opt,name=db_name,json=dbName,proto3" json:"db_name,omitempty"` 
CollectionName string `protobuf:"bytes,3,opt,name=collection_name,json=collectionName,proto3" json:"collection_name,omitempty"` diff --git a/internal/proto/rootcoordpb/root_coord.pb.go b/internal/proto/rootcoordpb/root_coord.pb.go index 7d16956e9bd7f..ff74594ed2d2e 100644 --- a/internal/proto/rootcoordpb/root_coord.pb.go +++ b/internal/proto/rootcoordpb/root_coord.pb.go @@ -793,28 +793,28 @@ type RootCoordClient interface { GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) - //* + // * // @brief This method is used to create collection // // @param CreateCollectionRequest, use to provide collection information to be created. // // @return Status CreateCollection(ctx context.Context, in *milvuspb.CreateCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) - //* + // * // @brief This method is used to delete collection. // // @param DropCollectionRequest, collection name is going to be deleted. // // @return Status DropCollection(ctx context.Context, in *milvuspb.DropCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) - //* + // * // @brief This method is used to test collection existence. // // @param HasCollectionRequest, collection name is going to be tested. // // @return BoolResponse HasCollection(ctx context.Context, in *milvuspb.HasCollectionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) - //* + // * // @brief This method is used to get collection schema. // // @param DescribeCollectionRequest, target collection name. 
@@ -825,28 +825,28 @@ type RootCoordClient interface { CreateAlias(ctx context.Context, in *milvuspb.CreateAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) DropAlias(ctx context.Context, in *milvuspb.DropAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) AlterAlias(ctx context.Context, in *milvuspb.AlterAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) - //* + // * // @brief This method is used to list all collections. // // @return StringListResponse, collection name list ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) AlterCollection(ctx context.Context, in *milvuspb.AlterCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) - //* + // * // @brief This method is used to create partition // // @return Status CreatePartition(ctx context.Context, in *milvuspb.CreatePartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) - //* + // * // @brief This method is used to drop partition // // @return Status DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) - //* + // * // @brief This method is used to test partition existence. // // @return BoolResponse HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) - //* + // * // @brief This method is used to show partition information // // @param ShowPartitionRequest, target collection name. 
@@ -854,7 +854,7 @@ type RootCoordClient interface { // @return StringListResponse ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) ShowPartitionsInternal(ctx context.Context, in *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) - // rpc DescribeSegment(milvus.DescribeSegmentRequest) returns (milvus.DescribeSegmentResponse) {} + // rpc DescribeSegment(milvus.DescribeSegmentRequest) returns (milvus.DescribeSegmentResponse) {} ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentsRequest, opts ...grpc.CallOption) (*milvuspb.ShowSegmentsResponse, error) AllocTimestamp(ctx context.Context, in *AllocTimestampRequest, opts ...grpc.CallOption) (*AllocTimestampResponse, error) AllocID(ctx context.Context, in *AllocIDRequest, opts ...grpc.CallOption) (*AllocIDResponse, error) @@ -1327,28 +1327,28 @@ type RootCoordServer interface { GetComponentStates(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) GetTimeTickChannel(context.Context, *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) GetStatisticsChannel(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) - //* + // * // @brief This method is used to create collection // // @param CreateCollectionRequest, use to provide collection information to be created. // // @return Status CreateCollection(context.Context, *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) - //* + // * // @brief This method is used to delete collection. // // @param DropCollectionRequest, collection name is going to be deleted. // // @return Status DropCollection(context.Context, *milvuspb.DropCollectionRequest) (*commonpb.Status, error) - //* + // * // @brief This method is used to test collection existence. // // @param HasCollectionRequest, collection name is going to be tested. 
// // @return BoolResponse HasCollection(context.Context, *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) - //* + // * // @brief This method is used to get collection schema. // // @param DescribeCollectionRequest, target collection name. @@ -1359,28 +1359,28 @@ type RootCoordServer interface { CreateAlias(context.Context, *milvuspb.CreateAliasRequest) (*commonpb.Status, error) DropAlias(context.Context, *milvuspb.DropAliasRequest) (*commonpb.Status, error) AlterAlias(context.Context, *milvuspb.AlterAliasRequest) (*commonpb.Status, error) - //* + // * // @brief This method is used to list all collections. // // @return StringListResponse, collection name list ShowCollections(context.Context, *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) AlterCollection(context.Context, *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) - //* + // * // @brief This method is used to create partition // // @return Status CreatePartition(context.Context, *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) - //* + // * // @brief This method is used to drop partition // // @return Status DropPartition(context.Context, *milvuspb.DropPartitionRequest) (*commonpb.Status, error) - //* + // * // @brief This method is used to test partition existence. // // @return BoolResponse HasPartition(context.Context, *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) - //* + // * // @brief This method is used to show partition information // // @param ShowPartitionRequest, target collection name. 
@@ -1388,7 +1388,7 @@ type RootCoordServer interface { // @return StringListResponse ShowPartitions(context.Context, *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) ShowPartitionsInternal(context.Context, *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) - // rpc DescribeSegment(milvus.DescribeSegmentRequest) returns (milvus.DescribeSegmentResponse) {} + // rpc DescribeSegment(milvus.DescribeSegmentRequest) returns (milvus.DescribeSegmentResponse) {} ShowSegments(context.Context, *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) AllocTimestamp(context.Context, *AllocTimestampRequest) (*AllocTimestampResponse, error) AllocID(context.Context, *AllocIDRequest) (*AllocIDResponse, error) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 7796de3560b5a..9e393a0917b4f 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -56,6 +56,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/util/componentutil" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/grpcclient" "github.com/milvus-io/milvus/internal/util/importutil" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/common" @@ -451,21 +452,21 @@ func TestProxy(t *testing.T) { go testServer.startGrpc(ctx, &wg, &p) assert.NoError(t, testServer.waitForGrpcReady()) - rootCoordClient, err := rcc.NewClient(ctx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdcli) + rootCoordClient, err := rcc.NewClient(ctx, grpcclient.NewRawEntryProvider(etcdcli, Params.EtcdCfg.MetaRootPath.GetValue(), typeutil.RootCoordRole)) assert.NoError(t, err) err = componentutil.WaitForComponentHealthy(ctx, rootCoordClient, typeutil.RootCoordRole, attempts, sleepDuration) assert.NoError(t, err) proxy.SetRootCoordClient(rootCoordClient) log.Info("Proxy set root coordinator client") - 
dataCoordClient, err := grpcdatacoordclient2.NewClient(ctx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdcli) + dataCoordClient, err := grpcdatacoordclient2.NewClient(ctx, grpcclient.NewRawEntryProvider(etcdcli, Params.EtcdCfg.MetaRootPath.GetValue(), typeutil.DataCoordRole)) assert.NoError(t, err) err = componentutil.WaitForComponentHealthy(ctx, dataCoordClient, typeutil.DataCoordRole, attempts, sleepDuration) assert.NoError(t, err) proxy.SetDataCoordClient(dataCoordClient) log.Info("Proxy set data coordinator client") - queryCoordClient, err := grpcquerycoordclient.NewClient(ctx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdcli) + queryCoordClient, err := grpcquerycoordclient.NewClient(ctx, grpcclient.NewRawEntryProvider(etcdcli, Params.EtcdCfg.MetaRootPath.GetValue(), typeutil.QueryCoordRole)) assert.NoError(t, err) err = componentutil.WaitForComponentHealthy(ctx, queryCoordClient, typeutil.QueryCoordRole, attempts, sleepDuration) assert.NoError(t, err) diff --git a/internal/registry/common/component.go b/internal/registry/common/component.go new file mode 100644 index 0000000000000..4026aca547103 --- /dev/null +++ b/internal/registry/common/component.go @@ -0,0 +1,30 @@ +package common + +import "github.com/milvus-io/milvus/pkg/util/typeutil" + +// specified type only +// mixcoord and standalone mode registers detail session in each component for now +var ( + coordinators = typeutil.NewSet( + typeutil.RootCoordRole, + typeutil.QueryCoordRole, + typeutil.DataCoordRole, + typeutil.IndexCoordRole, + ) + nodes = typeutil.NewSet( + typeutil.ProxyRole, + typeutil.QueryNodeRole, + typeutil.DataNodeRole, + typeutil.IndexNodeRole, + ) +) + +// IsCoordinator reports whether the provided component type is in the coordinators set. +func IsCoordinator(component string) bool { + return coordinators.Contain(component) +} + +// IsNode reports whether the provided component type is in the nodes set. 
+func IsNode(component string) bool { + return nodes.Contain(component) +} diff --git a/internal/registry/common/component_test.go b/internal/registry/common/component_test.go new file mode 100644 index 0000000000000..4bd218c0c0033 --- /dev/null +++ b/internal/registry/common/component_test.go @@ -0,0 +1,70 @@ +package common + +import ( + "testing" + + "github.com/milvus-io/milvus/pkg/util/typeutil" + "github.com/stretchr/testify/suite" +) + +type ComponentSuite struct { + suite.Suite +} + +func (s *ComponentSuite) TestIsCoordinator() { + type testCase struct { + componentName string + expected bool + } + + cases := []testCase{ + {typeutil.RootCoordRole, true}, + {typeutil.QueryCoordRole, true}, + {typeutil.IndexCoordRole, true}, + {typeutil.DataCoordRole, true}, + {typeutil.ProxyRole, false}, + {typeutil.QueryNodeRole, false}, + {typeutil.DataNodeRole, false}, + {typeutil.IndexNodeRole, false}, + {typeutil.EmbeddedRole, false}, + {typeutil.StandaloneRole, false}, + {"others", false}, + } + + for _, tc := range cases { + s.Run(tc.componentName, func() { + s.Equal(tc.expected, IsCoordinator(tc.componentName)) + }) + } +} + +func (s *ComponentSuite) TestIsNode() { + type testCase struct { + componentName string + expected bool + } + + cases := []testCase{ + {typeutil.RootCoordRole, false}, + {typeutil.QueryCoordRole, false}, + {typeutil.IndexCoordRole, false}, + {typeutil.DataCoordRole, false}, + {typeutil.ProxyRole, true}, + {typeutil.QueryNodeRole, true}, + {typeutil.DataNodeRole, true}, + {typeutil.IndexNodeRole, true}, + {typeutil.EmbeddedRole, false}, + {typeutil.StandaloneRole, false}, + {"others", false}, + } + + for _, tc := range cases { + s.Run(tc.componentName, func() { + s.Equal(tc.expected, IsNode(tc.componentName)) + }) + } +} + +func TestComponent(t *testing.T) { + suite.Run(t, new(ComponentSuite)) +} diff --git a/internal/registry/common/session.go b/internal/registry/common/session.go new file mode 100644 index 0000000000000..bd4b5648758b0 --- 
/dev/null +++ b/internal/registry/common/session.go @@ -0,0 +1,39 @@ +package common + +// ServiceEntryBase base implementation for ServiceEntry. +type ServiceEntryBase struct { + id int64 + addr string + compType string +} + +func (e *ServiceEntryBase) SetID(id int64) { + e.id = id +} + +func (e *ServiceEntryBase) SetAddr(addr string) { + e.addr = addr +} + +func (e *ServiceEntryBase) SetComponentType(compType string) { + e.compType = compType +} + +func (e *ServiceEntryBase) ID() int64 { + return e.id +} + +func (e *ServiceEntryBase) Addr() string { + return e.addr +} + +func (e *ServiceEntryBase) ComponentType() string { + return e.compType +} + +func NewServiceEntryBase(addr string, compType string) ServiceEntryBase { + return ServiceEntryBase{ + addr: addr, + compType: compType, + } +} diff --git a/internal/registry/etcd/mock_clientv3_kv_test.go b/internal/registry/etcd/mock_clientv3_kv_test.go new file mode 100644 index 0000000000000..32babd1feb473 --- /dev/null +++ b/internal/registry/etcd/mock_clientv3_kv_test.go @@ -0,0 +1,372 @@ +// Code generated by mockery v2.16.0. DO NOT EDIT. + +package etcd + +import ( + context "context" + + clientv3 "go.etcd.io/etcd/client/v3" + + mock "github.com/stretchr/testify/mock" +) + +// MockV3KV is an autogenerated mock type for the KV type +type MockV3KV struct { + mock.Mock +} + +type MockV3KV_Expecter struct { + mock *mock.Mock +} + +func (_m *MockV3KV) EXPECT() *MockV3KV_Expecter { + return &MockV3KV_Expecter{mock: &_m.Mock} +} + +// Compact provides a mock function with given fields: ctx, rev, opts +func (_m *MockV3KV) Compact(ctx context.Context, rev int64, opts ...clientv3.CompactOption) (*clientv3.CompactResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, rev) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *clientv3.CompactResponse + if rf, ok := ret.Get(0).(func(context.Context, int64, ...clientv3.CompactOption) *clientv3.CompactResponse); ok { + r0 = rf(ctx, rev, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*clientv3.CompactResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int64, ...clientv3.CompactOption) error); ok { + r1 = rf(ctx, rev, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockV3KV_Compact_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Compact' +type MockV3KV_Compact_Call struct { + *mock.Call +} + +// Compact is a helper method to define mock.On call +// - ctx context.Context +// - rev int64 +// - opts ...clientv3.CompactOption +func (_e *MockV3KV_Expecter) Compact(ctx interface{}, rev interface{}, opts ...interface{}) *MockV3KV_Compact_Call { + return &MockV3KV_Compact_Call{Call: _e.mock.On("Compact", + append([]interface{}{ctx, rev}, opts...)...)} +} + +func (_c *MockV3KV_Compact_Call) Run(run func(ctx context.Context, rev int64, opts ...clientv3.CompactOption)) *MockV3KV_Compact_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]clientv3.CompactOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(clientv3.CompactOption) + } + } + run(args[0].(context.Context), args[1].(int64), variadicArgs...) + }) + return _c +} + +func (_c *MockV3KV_Compact_Call) Return(_a0 *clientv3.CompactResponse, _a1 error) *MockV3KV_Compact_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// Delete provides a mock function with given fields: ctx, key, opts +func (_m *MockV3KV) Delete(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.DeleteResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, key) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *clientv3.DeleteResponse + if rf, ok := ret.Get(0).(func(context.Context, string, ...clientv3.OpOption) *clientv3.DeleteResponse); ok { + r0 = rf(ctx, key, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*clientv3.DeleteResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, ...clientv3.OpOption) error); ok { + r1 = rf(ctx, key, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockV3KV_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type MockV3KV_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - key string +// - opts ...clientv3.OpOption +func (_e *MockV3KV_Expecter) Delete(ctx interface{}, key interface{}, opts ...interface{}) *MockV3KV_Delete_Call { + return &MockV3KV_Delete_Call{Call: _e.mock.On("Delete", + append([]interface{}{ctx, key}, opts...)...)} +} + +func (_c *MockV3KV_Delete_Call) Run(run func(ctx context.Context, key string, opts ...clientv3.OpOption)) *MockV3KV_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]clientv3.OpOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(clientv3.OpOption) + } + } + run(args[0].(context.Context), args[1].(string), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockV3KV_Delete_Call) Return(_a0 *clientv3.DeleteResponse, _a1 error) *MockV3KV_Delete_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// Do provides a mock function with given fields: ctx, op +func (_m *MockV3KV) Do(ctx context.Context, op clientv3.Op) (clientv3.OpResponse, error) { + ret := _m.Called(ctx, op) + + var r0 clientv3.OpResponse + if rf, ok := ret.Get(0).(func(context.Context, clientv3.Op) clientv3.OpResponse); ok { + r0 = rf(ctx, op) + } else { + r0 = ret.Get(0).(clientv3.OpResponse) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, clientv3.Op) error); ok { + r1 = rf(ctx, op) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockV3KV_Do_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Do' +type MockV3KV_Do_Call struct { + *mock.Call +} + +// Do is a helper method to define mock.On call +// - ctx context.Context +// - op clientv3.Op +func (_e *MockV3KV_Expecter) Do(ctx interface{}, op interface{}) *MockV3KV_Do_Call { + return &MockV3KV_Do_Call{Call: _e.mock.On("Do", ctx, op)} +} + +func (_c *MockV3KV_Do_Call) Run(run func(ctx context.Context, op clientv3.Op)) *MockV3KV_Do_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(clientv3.Op)) + }) + return _c +} + +func (_c *MockV3KV_Do_Call) Return(_a0 clientv3.OpResponse, _a1 error) *MockV3KV_Do_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// Get provides a mock function with given fields: ctx, key, opts +func (_m *MockV3KV) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, key) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *clientv3.GetResponse + if rf, ok := ret.Get(0).(func(context.Context, string, ...clientv3.OpOption) *clientv3.GetResponse); ok { + r0 = rf(ctx, key, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*clientv3.GetResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, ...clientv3.OpOption) error); ok { + r1 = rf(ctx, key, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockV3KV_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type MockV3KV_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - key string +// - opts ...clientv3.OpOption +func (_e *MockV3KV_Expecter) Get(ctx interface{}, key interface{}, opts ...interface{}) *MockV3KV_Get_Call { + return &MockV3KV_Get_Call{Call: _e.mock.On("Get", + append([]interface{}{ctx, key}, opts...)...)} +} + +func (_c *MockV3KV_Get_Call) Run(run func(ctx context.Context, key string, opts ...clientv3.OpOption)) *MockV3KV_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]clientv3.OpOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(clientv3.OpOption) + } + } + run(args[0].(context.Context), args[1].(string), variadicArgs...) + }) + return _c +} + +func (_c *MockV3KV_Get_Call) Return(_a0 *clientv3.GetResponse, _a1 error) *MockV3KV_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// Put provides a mock function with given fields: ctx, key, val, opts +func (_m *MockV3KV) Put(ctx context.Context, key string, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, key, val) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *clientv3.PutResponse + if rf, ok := ret.Get(0).(func(context.Context, string, string, ...clientv3.OpOption) *clientv3.PutResponse); ok { + r0 = rf(ctx, key, val, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*clientv3.PutResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, string, ...clientv3.OpOption) error); ok { + r1 = rf(ctx, key, val, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockV3KV_Put_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Put' +type MockV3KV_Put_Call struct { + *mock.Call +} + +// Put is a helper method to define mock.On call +// - ctx context.Context +// - key string +// - val string +// - opts ...clientv3.OpOption +func (_e *MockV3KV_Expecter) Put(ctx interface{}, key interface{}, val interface{}, opts ...interface{}) *MockV3KV_Put_Call { + return &MockV3KV_Put_Call{Call: _e.mock.On("Put", + append([]interface{}{ctx, key, val}, opts...)...)} +} + +func (_c *MockV3KV_Put_Call) Run(run func(ctx context.Context, key string, val string, opts ...clientv3.OpOption)) *MockV3KV_Put_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]clientv3.OpOption, len(args)-3) + for i, a := range args[3:] { + if a != nil { + variadicArgs[i] = a.(clientv3.OpOption) + } + } + run(args[0].(context.Context), args[1].(string), args[2].(string), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockV3KV_Put_Call) Return(_a0 *clientv3.PutResponse, _a1 error) *MockV3KV_Put_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +// Txn provides a mock function with given fields: ctx +func (_m *MockV3KV) Txn(ctx context.Context) clientv3.Txn { + ret := _m.Called(ctx) + + var r0 clientv3.Txn + if rf, ok := ret.Get(0).(func(context.Context) clientv3.Txn); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(clientv3.Txn) + } + } + + return r0 +} + +// MockV3KV_Txn_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Txn' +type MockV3KV_Txn_Call struct { + *mock.Call +} + +// Txn is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockV3KV_Expecter) Txn(ctx interface{}) *MockV3KV_Txn_Call { + return &MockV3KV_Txn_Call{Call: _e.mock.On("Txn", ctx)} +} + +func (_c *MockV3KV_Txn_Call) Run(run func(ctx context.Context)) *MockV3KV_Txn_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockV3KV_Txn_Call) Return(_a0 clientv3.Txn) *MockV3KV_Txn_Call { + _c.Call.Return(_a0) + return _c +} + +type mockConstructorTestingTNewMockV3KV interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockV3KV creates a new instance of MockV3KV. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockV3KV(t mockConstructorTestingTNewMockV3KV) *MockV3KV { + mock := &MockV3KV{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/registry/etcd/raw_provider.go b/internal/registry/etcd/raw_provider.go new file mode 100644 index 0000000000000..8a3fcd5db733a --- /dev/null +++ b/internal/registry/etcd/raw_provider.go @@ -0,0 +1,71 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package etcd + +import ( + "context" + "encoding/json" + "path" + + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" + clientv3 "go.etcd.io/etcd/client/v3" +) + +type rawEntryProvider struct { + client clientv3.KV + metaPath string + role string +} + +// NewRawEntryProvider returns a RawEntryProvider for backward-compatibility. +// Works for coordinator service entries. +// Shall be removed when all service-discovery logic is unified. +func NewRawEntryProvider(cli clientv3.KV, metaPath string, role string) grpcclient.ServiceProvider { + return &rawEntryProvider{ + client: cli, + metaPath: metaPath, + role: role, + } +} + +// GetServiceEntry returns current service entry information. 
+func (p *rawEntryProvider) GetServiceEntry(ctx context.Context) (string, int64, error) { + key := path.Join(p.metaPath, common.DefaultServiceRoot, p.role) + + resp, err := p.client.Get(ctx, key) + if err != nil { + err := merr.WrapErrIoFailed(key, err.Error()) + return "", 0, err + } + + if len(resp.Kvs) != 1 { + err := merr.WrapErrIoKeyNotFound(key) + return "", 0, err + } + + kv := resp.Kvs[0] + session := etcdSession{} + err = json.Unmarshal(kv.Value, &session) + if err != nil { + err := merr.WrapErrIoKeyNotFound(key, err.Error()) + return "", 0, err + } + + return session.Addr(), session.ID(), nil +} diff --git a/internal/registry/etcd/raw_provider_test.go b/internal/registry/etcd/raw_provider_test.go new file mode 100644 index 0000000000000..f4d2a3b04be2f --- /dev/null +++ b/internal/registry/etcd/raw_provider_test.go @@ -0,0 +1,115 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package etcd + +import ( + "context" + "encoding/json" + "testing" + + "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus/internal/registry/common" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/atomic" +) + +type RawEntryProviderSuite struct { + suite.Suite +} + +func (s *RawEntryProviderSuite) TestSuccessRun() { + kv := NewMockV3KV(s.T()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + session := &etcdSession{ + ServiceEntryBase: common.NewServiceEntryBase("addr", "role"), + stopping: atomic.NewBool(false), + } + session.SetID(100) + bs, err := json.Marshal(session) + s.NoError(err) + kv.EXPECT().Get(ctx, mock.AnythingOfType("string")).Return(&clientv3.GetResponse{ + Kvs: []*mvccpb.KeyValue{ + { + Key: []byte("meta/session/role"), + Value: bs, + }, + }, + }, nil) + rp := NewRawEntryProvider(kv, "meta", "role") + addr, id, err := rp.GetServiceEntry(ctx) + s.NoError(err) + s.Equal(session.Addr(), addr) + s.Equal(session.ID(), id) +} + +func (s *RawEntryProviderSuite) TestGetError() { + kv := NewMockV3KV(s.T()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockedErr := errors.New("mocked") + kv.EXPECT().Get(ctx, mock.AnythingOfType("string")).Return(nil, mockedErr) + rp := NewRawEntryProvider(kv, "meta", "role") + _, _, err := rp.GetServiceEntry(ctx) + s.Error(err) + s.True(merr.ErrIoFailed.Is(err)) +} + +func (s *RawEntryProviderSuite) TestKeyNotFound() { + kv := NewMockV3KV(s.T()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + kv.EXPECT().Get(ctx, mock.AnythingOfType("string")).Return(&clientv3.GetResponse{}, nil) + rp := NewRawEntryProvider(kv, "meta", "role") + + _, _, err := rp.GetServiceEntry(ctx) + s.Error(err) + s.True(merr.ErrIoKeyNotFound.Is(err)) +} + +func (s *RawEntryProviderSuite) 
TestUnmarshalFailed() { + kv := NewMockV3KV(s.T()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + kv.EXPECT().Get(ctx, mock.AnythingOfType("string")).Return(&clientv3.GetResponse{ + Kvs: []*mvccpb.KeyValue{ + { + Key: []byte("meta/session/role"), + Value: nil, + }, + }, + }, nil) + rp := NewRawEntryProvider(kv, "meta", "role") + + _, _, err := rp.GetServiceEntry(ctx) + s.Error(err) + s.True(merr.ErrIoKeyNotFound.Is(err)) +} + +func TestRawEntryProvider(t *testing.T) { + suite.Run(t, new(RawEntryProviderSuite)) +} diff --git a/internal/registry/etcd/service_discovery.go b/internal/registry/etcd/service_discovery.go new file mode 100644 index 0000000000000..490ffc766135f --- /dev/null +++ b/internal/registry/etcd/service_discovery.go @@ -0,0 +1,256 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package etcd + +import ( + "context" + "encoding/json" + "fmt" + "path" + + "github.com/cockroachdb/errors" + grpcdatacoordclient "github.com/milvus-io/milvus/internal/distributed/datacoord/client" + grpcdatanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client" + grpcindexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client" + grpcproxyclient "github.com/milvus-io/milvus/internal/distributed/proxy/client" + grpcquerycoordclient "github.com/milvus-io/milvus/internal/distributed/querycoord/client" + grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client" + grpcrootcoordclient "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" + "github.com/milvus-io/milvus/internal/registry" + "github.com/milvus-io/milvus/internal/registry/common" + "github.com/milvus-io/milvus/internal/registry/options" + "github.com/milvus-io/milvus/internal/types" + milvuscommon "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" + "github.com/samber/lo" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" +) + +type etcdServiceDiscovery struct { + client *clientv3.Client + metaRoot string +} + +// TODO add more option to constructor + +func NewEtcdServiceDiscovery(cli *clientv3.Client, metaRoot string) registry.ServiceDiscovery { + return &etcdServiceDiscovery{ + client: cli, + metaRoot: metaRoot, + } +} + +func (s *etcdServiceDiscovery) getServices(ctx context.Context, component string) ([]registry.ServiceEntry, int64, error) { + prefix := path.Join(s.metaRoot, milvuscommon.DefaultServiceRoot, component) + resp, err := s.client.Get(ctx, prefix, clientv3.WithPrefix()) + if err != nil { + if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) { + return nil, -1, errors.Wrapf(err, "context canceled when list 
%s service", component) + } + return nil, -1, merr.WrapErrIoFailed(prefix, err.Error()) + } + + result := make([]registry.ServiceEntry, 0, len(resp.Kvs)) + for _, kv := range resp.Kvs { + session := &etcdSession{} + err = json.Unmarshal(kv.Value, session) + if err != nil { + log.Warn("there maybe some corrupted session in etcd", + zap.String("key", string(kv.Key)), + zap.Error(err), + ) + continue + } + _, mapKey := path.Split(string(kv.Key)) + log.Debug("etcdServiceDiscovery GetSessions", + zap.String("prefix", prefix), + zap.String("key", mapKey), + zap.String("address", session.Addr()), + ) + result = append(result, session) + } + return result, resp.Header.GetRevision(), nil + +} + +func (s *etcdServiceDiscovery) GetServices(ctx context.Context, component string) ([]registry.ServiceEntry, error) { + result, _, err := s.getServices(ctx, component) + return result, err +} + +func (s *etcdServiceDiscovery) WatchServices(ctx context.Context, component string, opts ...options.WatchOption) ([]registry.ServiceEntry, registry.ServiceWatcher[registry.ServiceEntry], error) { + return listWatch(ctx, s, component, func(ctx context.Context, entry registry.ServiceEntry) (registry.ServiceEntry, error) { + return entry, nil + }) +} + +func (s *etcdServiceDiscovery) getCoordinatorEntry(ctx context.Context, component string) (string, int64, error) { + key := path.Join(s.metaRoot, milvuscommon.DefaultServiceRoot, component) + resp, err := s.client.Get(ctx, key) + if err != nil { + if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) { + return "", 0, errors.Wrapf(err, "context canceled when list %s service", component) + } + return "", 0, merr.WrapErrIoFailed(key, err.Error()) + } + + if len(resp.Kvs) != 1 { + return "", 0, merr.WrapErrIoKeyNotFound(key, fmt.Sprintf("get %s session entry not found", component)) + } + kv := resp.Kvs[0] + session := etcdSession{} + err = json.Unmarshal(kv.Value, &session) + if err != nil { + log.Warn("there maybe some corrupted session 
in etcd", + zap.String("key", string(kv.Key)), + zap.Error(err), + ) + return "", 0, merr.WrapErrIoKeyNotFound(key, fmt.Sprintf("get %s session unmarshal failed, %s", component, err.Error())) + } + + return session.Addr(), session.ID(), nil +} + +// GetRootCoord returns RootCoord grpc client as types.RootCoord. +func (s *etcdServiceDiscovery) GetRootCoord(ctx context.Context) (types.RootCoord, error) { + return grpcrootcoordclient.NewClient(ctx, s.getServiceProvider(typeutil.RootCoordRole)) +} + +// GetQueryCoord returns QueryCoord grpc client as types.QueryCoord. +func (s *etcdServiceDiscovery) GetQueryCoord(ctx context.Context) (types.QueryCoord, error) { + return grpcquerycoordclient.NewClient(ctx, s.getServiceProvider(typeutil.QueryCoordRole)) +} + +// GetDataCoord returns DataCoord grpc client as types.DataCoord. +func (s *etcdServiceDiscovery) GetDataCoord(ctx context.Context) (types.DataCoord, error) { + return grpcdatacoordclient.NewClient(ctx, s.getServiceProvider(typeutil.DataCoordRole)) +} + +// WatchDataNode returns current DataNode instances and chan to watch. +func (s *etcdServiceDiscovery) WatchDataNode(ctx context.Context, opts ...options.WatchOption) ([]types.DataNode, registry.ServiceWatcher[types.DataNode], error) { + return listWatch(ctx, s, typeutil.DataNodeRole, func(ctx context.Context, entry registry.ServiceEntry) (types.DataNode, error) { + return grpcdatanodeclient.NewClient(ctx, entry.Addr(), entry.ID()) + }) +} + +// WatchQueryNode returns current QueryNode instance and chan to watch. 
+func (s *etcdServiceDiscovery) WatchQueryNode(ctx context.Context, opts ...options.WatchOption) ([]types.QueryNode, registry.ServiceWatcher[types.QueryNode], error) { + return listWatch(ctx, s, typeutil.QueryNodeRole, func(ctx context.Context, entry registry.ServiceEntry) (types.QueryNode, error) { + return grpcquerynodeclient.NewClient(ctx, entry.Addr(), entry.ID()) + }) +} + +func (s *etcdServiceDiscovery) WatchIndexNode(ctx context.Context, opts ...options.WatchOption) ([]types.IndexNode, registry.ServiceWatcher[types.IndexNode], error) { + return listWatch(ctx, s, typeutil.QueryNodeRole, func(ctx context.Context, entry registry.ServiceEntry) (types.IndexNode, error) { + return grpcindexnodeclient.NewClient(ctx, entry.Addr(), entry.ID(), paramtable.Get().DataCoordCfg.WithCredential.GetAsBool()) + }) +} + +func (s *etcdServiceDiscovery) WatchProxy(ctx context.Context, opts ...options.WatchOption) ([]types.Proxy, registry.ServiceWatcher[types.Proxy], error) { + return listWatch(ctx, s, typeutil.ProxyRole, func(ctx context.Context, entry registry.ServiceEntry) (types.Proxy, error) { + return grpcproxyclient.NewClient(ctx, entry.Addr(), entry.ID()) + }) +} + +func (s *etcdServiceDiscovery) RegisterRootCoord(ctx context.Context, rootcoord types.RootCoord, addr string, opts ...options.RegisterOption) (registry.Session, error) { + return s.registerService(ctx, typeutil.RootCoordRole, addr, opts...) +} + +func (s *etcdServiceDiscovery) RegisterDataCoord(ctx context.Context, datacoord types.DataCoord, addr string, opts ...options.RegisterOption) (registry.Session, error) { + return s.registerService(ctx, typeutil.DataCoordRole, addr, opts...) +} + +func (s *etcdServiceDiscovery) RegisterQueryCoord(ctx context.Context, querycoord types.QueryCoord, addr string, opts ...options.RegisterOption) (registry.Session, error) { + return s.registerService(ctx, typeutil.QueryCoordRole, addr, opts...) 
+} + +func (s *etcdServiceDiscovery) RegisterProxy(ctx context.Context, proxy types.Proxy, addr string, opts ...options.RegisterOption) (registry.Session, error) { + return s.registerService(ctx, typeutil.ProxyRole, addr, opts...) +} + +func (s *etcdServiceDiscovery) RegisterDataNode(ctx context.Context, datanode types.DataNode, addr string, opts ...options.RegisterOption) (registry.Session, error) { + return s.registerService(ctx, typeutil.DataNodeRole, addr, opts...) +} + +func (s *etcdServiceDiscovery) RegisterIndexNode(ctx context.Context, indexnode types.IndexNode, addr string, opts ...options.RegisterOption) (registry.Session, error) { + return s.registerService(ctx, typeutil.IndexNodeRole, addr, opts...) +} + +func (s *etcdServiceDiscovery) RegisterQueryNode(ctx context.Context, querynode types.QueryNode, addr string, opts ...options.RegisterOption) (registry.Session, error) { + return s.registerService(ctx, typeutil.QueryNodeRole, addr, opts...) +} + +func (s *etcdServiceDiscovery) registerService(ctx context.Context, component string, addr string, opts ...options.RegisterOption) (registry.Session, error) { + var exclusive bool + switch { + case common.IsCoordinator(component): + exclusive = true + case common.IsNode(component): + exclusive = false + default: + return nil, merr.WrapErrParameterInvalid("legal component type", component) + } + + opt := options.DefaultSessionOpt() + for _, o := range opts { + o(&opt) + } + opt.Exclusive = exclusive + + return newEtcdSession(s.client, s.metaRoot, addr, component, opt), nil +} + +func (s *etcdServiceDiscovery) getServiceProvider(component string) serviceProvider { + return func(ctx context.Context) (string, int64, error) { + return s.getCoordinatorEntry(ctx, component) + } +} + +func listWatch[T any](ctx context.Context, s *etcdServiceDiscovery, component string, convert func(ctx context.Context, entry registry.ServiceEntry) (T, error)) ([]T, registry.ServiceWatcher[T], error) { + current, revision, err := 
s.getServices(ctx, component) + if err != nil { + return nil, nil, err + } + + prefix := path.Join(s.metaRoot, milvuscommon.DefaultServiceRoot, component) + etcdCh := s.client.Watch(ctx, prefix, clientv3.WithPrefix(), clientv3.WithPrevKV(), clientv3.WithRev(revision)) + w := newWatcher[T](s, prefix, etcdCh, convert) + + return lo.FilterMap(current, func(entry registry.ServiceEntry, _ int) (T, bool) { + var empty T + addr := entry.Addr() + c, err := convert(ctx, entry) + if err != nil { + log.Warn("failed to create client", + zap.String("addr", addr), + zap.String("component", component), + zap.Error(err), + ) + return empty, false + } + return c, true + }), w, nil +} + +// serviceProvider wraps function to grpcclient.ServiceProvider interface. +type serviceProvider func(ctx context.Context) (string, int64, error) + +func (sp serviceProvider) GetServiceEntry(ctx context.Context) (string, int64, error) { return sp(ctx) } diff --git a/internal/registry/etcd/service_discovery_test.go b/internal/registry/etcd/service_discovery_test.go new file mode 100644 index 0000000000000..c50e150db71af --- /dev/null +++ b/internal/registry/etcd/service_discovery_test.go @@ -0,0 +1,481 @@ +package etcd + +import ( + "context" + "fmt" + "net" + "os" + "path" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus-proto/go-api/commonpb" + "github.com/milvus-io/milvus-proto/go-api/milvuspb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/registry" + "github.com/milvus-io/milvus/internal/registry/options" + milvuscommon "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" + "github.com/samber/lo" + 
"github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/server/v3/embed" + "go.etcd.io/etcd/server/v3/etcdserver/api/v3client" + "google.golang.org/grpc" +) + +type EtcdSDSuite struct { + // suite fields + etcd *embed.Etcd + tmpFolder string + + // test fields + prefix string + etcdCli *clientv3.Client + sd *etcdServiceDiscovery + sessions []registry.Session + + suite.Suite +} + +func (s *EtcdSDSuite) SetupSuite() { + paramtable.Init() + server, tmpFolder, err := etcd.StartTestEmbedEtcdServer() + s.Require().NoError(err) + s.etcd = server + s.tmpFolder = tmpFolder + +} + +func (s *EtcdSDSuite) TearDownSuite() { + s.etcd.Server.Stop() + os.RemoveAll(s.tmpFolder) +} + +func (s *EtcdSDSuite) SetupTest() { + s.prefix = funcutil.RandomString(6) + s.etcdCli = v3client.New(s.etcd.Server) + raw := NewEtcdServiceDiscovery(s.etcdCli, s.prefix) + var ok bool + s.sd, ok = raw.(*etcdServiceDiscovery) + s.Require().True(ok) +} + +func (s *EtcdSDSuite) TearDownTest() { + s.etcdCli.Delete(context.Background(), s.prefix, clientv3.WithPrefix()) + s.etcdCli.Close() +} + +func (s *EtcdSDSuite) TestGetServices() { + sessions := s.prepareDataset() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + for _, serverType := range typeutil.ServerTypeList() { + target := lo.Filter(sessions, func(session registry.Session, _ int) bool { + return session.ComponentType() == serverType + }) + m := lo.SliceToMap(target, func(session registry.Session) (int64, registry.Session) { + return session.ID(), session + }) + + s.Run(serverType, func() { + entries, err := s.sd.GetServices(ctx, serverType) + s.NoError(err) + s.Equal(len(target), len(entries)) + for _, entry := range entries { + s.Equal(serverType, entry.ComponentType()) + targetEntry, ok := m[entry.ID()] + if s.True(ok) { + s.Equal(targetEntry.Addr(), entry.Addr()) + } + } + }) + } + }) + + s.Run("ctx_canceld", 
func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + for _, serverType := range typeutil.ServerTypeList() { + s.Run(serverType, func() { + _, err := s.sd.GetServices(ctx, serverType) + s.Error(err) + }) + } + }) + +} + +type MockRCServer struct { + mock.Mock + *rootcoordpb.UnimplementedRootCoordServer +} + +func (s *MockRCServer) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { + ret := s.Called(ctx, req) + + var r0 *milvuspb.ComponentStates + if rf, ok := ret.Get(0).(func(ctx context.Context, req *milvuspb.GetComponentStatesRequest) *milvuspb.ComponentStates); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Get(0).(*milvuspb.ComponentStates) + } + + var r1 error + if rf, ok := ret.Get(1).(func(ctx context.Context, req *milvuspb.GetComponentStatesRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type MockDCServer struct { + mock.Mock + *datapb.UnimplementedDataCoordServer +} + +func (s *MockDCServer) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { + ret := s.Called(ctx, req) + + var r0 *milvuspb.ComponentStates + if rf, ok := ret.Get(0).(func(ctx context.Context, req *milvuspb.GetComponentStatesRequest) *milvuspb.ComponentStates); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Get(0).(*milvuspb.ComponentStates) + } + + var r1 error + if rf, ok := ret.Get(1).(func(ctx context.Context, req *milvuspb.GetComponentStatesRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type MockQCServer struct { + mock.Mock + *querypb.UnimplementedQueryCoordServer +} + +func (s *MockQCServer) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { + ret := s.Called(ctx, req) + + var r0 *milvuspb.ComponentStates + if rf, ok := ret.Get(0).(func(ctx 
context.Context, req *milvuspb.GetComponentStatesRequest) *milvuspb.ComponentStates); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Get(0).(*milvuspb.ComponentStates) + } + + var r1 error + if rf, ok := ret.Get(1).(func(ctx context.Context, req *milvuspb.GetComponentStatesRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +func (s *EtcdSDSuite) TestGetRootCoord() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("not_available", func() { + rc, err := s.sd.GetRootCoord(ctx) + s.NoError(err) + + _, err = rc.GetComponentStates(ctx) + s.Error(err) + }) + + // start grpc server + port := funcutil.GetAvailablePort() + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + s.Require().NoError(err) + server := grpc.NewServer() + mockRC := &MockRCServer{} + rootcoordpb.RegisterRootCoordServer(server, mockRC) + + go func() { + server.Serve(lis) + }() + + session, err := s.sd.RegisterRootCoord(ctx, nil, fmt.Sprintf("localhost:%d", port)) + s.Require().NoError(err) + s.setupSession(ctx, session) + + s.Run("success", func() { + mockRC.On("GetComponentStates", mock.Anything, mock.AnythingOfType("*milvuspb.GetComponentStatesRequest")).Return(&milvuspb.ComponentStates{ + Status: &commonpb.Status{}, + }, nil) + rc, err := s.sd.GetRootCoord(ctx) + s.NoError(err) + + _, err = rc.GetComponentStates(ctx) + s.NoError(err) + + mockRC.AssertCalled(s.T(), "GetComponentStates", mock.Anything, mock.AnythingOfType("*milvuspb.GetComponentStatesRequest")) + }) + + s.Run("getEntry_ContextCancelled", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + rc, err := s.sd.GetRootCoord(ctx) + s.NoError(err) + + _, err = rc.GetComponentStates(ctx) + s.Error(err) + }) +} + +func (s *EtcdSDSuite) TestGetDataCoord() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + port := funcutil.GetAvailablePort() + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + 
s.Require().NoError(err) + server := grpc.NewServer() + mockDC := &MockDCServer{} + datapb.RegisterDataCoordServer(server, mockDC) + + go func() { + server.Serve(lis) + }() + + session, err := s.sd.RegisterDataCoord(ctx, nil, fmt.Sprintf("localhost:%d", port)) + s.Require().NoError(err) + err = session.Init(ctx) + s.Require().NoError(err) + err = session.Register(ctx) + s.Require().NoError(err) + + mockDC.On("GetComponentStates", mock.Anything, mock.AnythingOfType("*milvuspb.GetComponentStatesRequest")).Return(&milvuspb.ComponentStates{ + Status: &commonpb.Status{}, + }, nil) + dc, err := s.sd.GetDataCoord(ctx) + s.NoError(err) + + _, err = dc.GetComponentStates(ctx) + s.NoError(err) + + mockDC.AssertCalled(s.T(), "GetComponentStates", mock.Anything, mock.AnythingOfType("*milvuspb.GetComponentStatesRequest")) +} + +func (s *EtcdSDSuite) TestGetQueryCoord() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + port := funcutil.GetAvailablePort() + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + s.Require().NoError(err) + server := grpc.NewServer() + mockQC := &MockQCServer{} + querypb.RegisterQueryCoordServer(server, mockQC) + + go func() { + server.Serve(lis) + }() + + session, err := s.sd.RegisterQueryCoord(ctx, nil, fmt.Sprintf("localhost:%d", port)) + s.Require().NoError(err) + s.setupSession(ctx, session) + + mockQC.On("GetComponentStates", mock.Anything, mock.AnythingOfType("*milvuspb.GetComponentStatesRequest")).Return(&milvuspb.ComponentStates{ + Status: &commonpb.Status{}, + }, nil) + qc, err := s.sd.GetQueryCoord(ctx) + s.NoError(err) + + _, err = qc.GetComponentStates(ctx) + s.NoError(err) + + mockQC.AssertCalled(s.T(), "GetComponentStates", mock.Anything, mock.AnythingOfType("*milvuspb.GetComponentStatesRequest")) +} + +func (s *EtcdSDSuite) TestWatchService() { + sessions := s.prepareDataset() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for _, serverType := range 
[]string{typeutil.ProxyRole, typeutil.DataNodeRole, typeutil.IndexNodeRole, typeutil.QueryNodeRole} { + + target := lo.Filter(sessions, func(session registry.Session, _ int) bool { + return session.ComponentType() == serverType + }) + m := lo.SliceToMap(target, func(session registry.Session) (int64, registry.Session) { + return session.ID(), session + }) + + s.Run(serverType, func() { + current, watcher, err := s.sd.WatchServices(ctx, serverType) + s.NoError(err) + s.Equal(len(target), len(current)) + for _, entry := range current { + s.Equal(serverType, entry.ComponentType()) + targetEntry, ok := m[entry.ID()] + if s.True(ok) { + s.Equal(targetEntry.Addr(), entry.Addr()) + } + } + + session, err := s.sd.registerService(ctx, serverType, "addr") + s.Require().NoError(err) + s.setupSession(ctx, session) + + select { + case evt := <-watcher.Watch(): + s.Equal(registry.SessionAddEvent, evt.EventType) + s.Equal(session.Addr(), evt.Entry.Addr()) + s.Equal(session.ID(), evt.Entry.ID()) + case <-time.After(time.Second): + s.FailNow("no watch event after 1 second") + } + }) + } +} + +func (s *EtcdSDSuite) TestListWatch() { + sessions := s.prepareDataset() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + for _, serverType := range []string{typeutil.ProxyRole, typeutil.DataNodeRole, typeutil.IndexNodeRole, typeutil.QueryNodeRole} { + + target := lo.Filter(sessions, func(session registry.Session, _ int) bool { + return session.ComponentType() == serverType + }) + m := lo.SliceToMap(target, func(session registry.Session) (int64, registry.Session) { + return session.ID(), session + }) + + s.Run(serverType, func() { + //current, watcher, err := s.sd.WatchServices(ctx, serverType) + current, watcher, err := listWatch(ctx, s.sd, serverType, func(ctx context.Context, entry registry.ServiceEntry) (registry.ServiceEntry, error) { + return entry, nil + }) + s.NoError(err) + s.Equal(len(target), len(current)) + for _, entry := 
range current { + s.Equal(serverType, entry.ComponentType()) + targetEntry, ok := m[entry.ID()] + if s.True(ok) { + s.Equal(targetEntry.Addr(), entry.Addr()) + } + } + + session, err := s.sd.registerService(ctx, serverType, "addr") + s.Require().NoError(err) + s.setupSession(ctx, session) + + select { + case evt := <-watcher.Watch(): + s.Equal(registry.SessionAddEvent, evt.EventType) + s.Equal(session.Addr(), evt.Entry.Addr()) + s.Equal(session.ID(), evt.Entry.ID()) + case <-time.After(time.Second): + s.FailNow("no watch event after 1 second") + } + }) + } + }) + + s.Run("ctx_canceled", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + for _, serverType := range []string{typeutil.ProxyRole, typeutil.DataNodeRole, typeutil.IndexNodeRole, typeutil.QueryNodeRole} { + s.Run(serverType, func() { + _, _, err := listWatch(ctx, s.sd, serverType, func(ctx context.Context, entry registry.ServiceEntry) (registry.ServiceEntry, error) { + return entry, nil + }) + s.Error(err) + }) + } + }) + + s.Run("converter_error", func() { + for _, serverType := range []string{typeutil.ProxyRole, typeutil.DataNodeRole, typeutil.IndexNodeRole, typeutil.QueryNodeRole} { + s.Run(serverType, func() { + current, _, err := listWatch(ctx, s.sd, serverType, func(ctx context.Context, entry registry.ServiceEntry) (registry.ServiceEntry, error) { + return nil, errors.New("mocked") + }) + s.NoError(err) + s.Equal(0, len(current)) + }) + } + }) +} + +func (s *EtcdSDSuite) setupSession(ctx context.Context, session registry.Session) { + err := session.Init(ctx) + s.Require().NoError(err) + err = session.Register(ctx) + s.Require().NoError(err) +} + +func (s *EtcdSDSuite) prepareDataset() []registry.Session { + var sessions []registry.Session + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + session, err := s.sd.RegisterProxy(ctx, nil, "proxy-addr", options.WithTriggerKill(false)) + s.Require().NoError(err) + s.setupSession(ctx, session) + sessions 
= append(sessions, session) + + session, err = s.sd.RegisterRootCoord(ctx, nil, "rc-addr", options.WithTriggerKill(false)) + s.Require().NoError(err) + s.setupSession(ctx, session) + sessions = append(sessions, session) + + session, err = s.sd.RegisterQueryCoord(ctx, nil, "qc-addr", options.WithTriggerKill(false)) + s.Require().NoError(err) + s.setupSession(ctx, session) + sessions = append(sessions, session) + + session, err = s.sd.RegisterDataCoord(ctx, nil, "dc-addr", options.WithTriggerKill(false)) + s.Require().NoError(err) + s.setupSession(ctx, session) + sessions = append(sessions, session) + + session, err = s.sd.RegisterDataNode(ctx, nil, "dn-addr", options.WithTriggerKill(false)) + s.Require().NoError(err) + s.setupSession(ctx, session) + sessions = append(sessions, session) + + session, err = s.sd.RegisterQueryNode(ctx, nil, "qn-addr", options.WithTriggerKill(false)) + s.Require().NoError(err) + s.setupSession(ctx, session) + sessions = append(sessions, session) + + session, err = s.sd.RegisterIndexNode(ctx, nil, "in-addr", options.WithTriggerKill(false)) + s.Require().NoError(err) + s.setupSession(ctx, session) + sessions = append(sessions, session) + + // garbage data + _, err = s.etcdCli.Put(ctx, path.Join(s.prefix, milvuscommon.DefaultServiceRoot, "proxy-9999"), "") + s.Require().NoError(err) + + return sessions +} + +func TestEtcdServiceDiscovery(t *testing.T) { + suite.Run(t, new(EtcdSDSuite)) +} diff --git a/internal/registry/etcd/session.go b/internal/registry/etcd/session.go new file mode 100644 index 0000000000000..4e1af8461d70a --- /dev/null +++ b/internal/registry/etcd/session.go @@ -0,0 +1,404 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package etcd + +import ( + "context" + "encoding/json" + "fmt" + "path" + "strconv" + "sync" + + "github.com/blang/semver/v4" + "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus/internal/registry" + "github.com/milvus-io/milvus/internal/registry/common" + "github.com/milvus-io/milvus/internal/registry/options" + milvuscommon "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/retry" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +var _ registry.Session = (*etcdSession)(nil) + +type etcdSession struct { + common.ServiceEntryBase + + idOnce sync.Once + + // session info + exclusive bool + stopping *atomic.Bool + version semver.Version + client *clientv3.Client + leaseID clientv3.LeaseID + isStandby *atomic.Bool + keepaliveCancel context.CancelFunc + triggerKill bool + liveCh <-chan struct{} + closeCh <-chan struct{} + + // options + metaRoot string + useCustomConfig bool + sessionTTL int64 + sessionRetryTimes int64 + reuseNodeID bool + enableActiveStandBy bool +} + +func newEtcdSession(cli *clientv3.Client, metaRoot string, addr string, component string, + opt options.SessionOpt) *etcdSession { + return &etcdSession{ + ServiceEntryBase: common.NewServiceEntryBase(addr, 
component), + exclusive: opt.Exclusive, + client: cli, + metaRoot: metaRoot, + isStandby: atomic.NewBool(opt.StandBy), + version: milvuscommon.Version, + stopping: atomic.NewBool(false), + } +} + +// UnmarshalJSON unmarshal bytes to Session. +func (s *etcdSession) UnmarshalJSON(data []byte) error { + var raw struct { + ServerID int64 `json:"ServerID,omitempty"` + ServerName string `json:"ServerName,omitempty"` + Address string `json:"Address,omitempty"` + Exclusive bool `json:"Exclusive,omitempty"` + Stopping bool `json:"Stopping,omitempty"` + Version string `json:"Version"` + } + err := json.Unmarshal(data, &raw) + if err != nil { + return err + } + + if raw.Version != "" { + s.version, err = semver.Parse(raw.Version) + if err != nil { + return err + } + } + + s.SetID(raw.ServerID) + s.SetAddr(raw.Address) + s.SetComponentType(raw.ServerName) + s.exclusive = raw.Exclusive + s.stopping = atomic.NewBool(raw.Stopping) + return nil +} + +// MarshalJSON marshals session to bytes. +func (s *etcdSession) MarshalJSON() ([]byte, error) { + + verStr := s.version.String() + return json.Marshal(&struct { + ServerID int64 `json:"ServerID,omitempty"` + ServerName string `json:"ServerName,omitempty"` + Address string `json:"Address,omitempty"` + Exclusive bool `json:"Exclusive,omitempty"` + Stopping bool `json:"Stopping,omitempty"` + TriggerKill bool + Version string `json:"Version"` + }{ + ServerID: s.ID(), + ServerName: s.ComponentType(), + Address: s.Addr(), + Exclusive: s.exclusive, + Stopping: s.stopping.Load(), + TriggerKill: s.triggerKill, + Version: verStr, + }) + +} + +// Init will initialize serverID with CAS operation with etcd `ID` kv. 
+func (s *etcdSession) Init(ctx context.Context) error { + // use custom config, for migration tool only + if s.useCustomConfig { + return nil + } + s.checkIDExist(ctx) + serverID, err := s.getServerID(ctx) + if err != nil { + return err + } + s.SetID(serverID) + return nil +} + +func (s *etcdSession) checkIDExist(ctx context.Context) { + s.client.Txn(ctx).If( + clientv3.Compare( + clientv3.Version(path.Join(s.metaRoot, milvuscommon.DefaultServiceRoot, milvuscommon.DefaultIDKey)), + "=", + 0)). + Then(clientv3.OpPut(path.Join(s.metaRoot, milvuscommon.DefaultServiceRoot, milvuscommon.DefaultIDKey), "1")).Commit() +} + +// Register writes session entry in etcd paths with ttl lease. +// internally it will start `keepAlive` to keep alive with etcd. +func (s *etcdSession) Register(ctx context.Context) error { + ch, err := s.registerService(ctx) + if err != nil { + log.Error("session register failed", zap.Error(err)) + return err + } + s.liveCh = s.keepAlive(ch) + return nil +} + +func (s *etcdSession) Revoke(ctx context.Context) error { + if s == nil { + return nil + } + + if s.client == nil || s.leaseID == 0 { + // TODO audit error type here + return merr.WrapErrParameterInvalid("valid session", "session not valid") + } + + _, err := s.client.Revoke(ctx, s.leaseID) + if err != nil { + if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) { + return errors.Wrapf(err, "context canceled when revoking component %s serverID:%d", s.ComponentType(), s.ID()) + } + return merr.WrapErrIoFailed(fmt.Sprintf("%s-%d", s.ComponentType(), s.ID()), err.Error()) + } + + return nil +} + +// Stop marks session as `Stopping` state for graceful stop. 
+func (s *etcdSession) Stop(ctx context.Context) error { + if s == nil || s.client == nil || s.leaseID == 0 { + return merr.WrapErrParameterInvalid("valid session", "session not valid") + } + + completeKey := s.getCompleteKey() + log := log.Ctx(ctx).With( + zap.String("key", completeKey), + ) + resp, err := s.client.Get(ctx, completeKey, clientv3.WithCountOnly()) + if err != nil { + log.Error("fail to get the session", zap.Error(err)) + return err + } + if resp.Count == 0 { + return nil + } + s.stopping.Store(true) + sessionJSON, err := json.Marshal(s) + if err != nil { + log.Error("fail to marshal the session") + return err + } + _, err = s.client.Put(ctx, completeKey, string(sessionJSON), clientv3.WithLease(s.leaseID)) + if err != nil { + log.Error("fail to update the session to stopping state") + return err + } + return nil +} + +func (s *etcdSession) getServerID(ctx context.Context) (int64, error) { + var nodeID int64 + var err error + s.idOnce.Do(func() { + log.Debug("getServerID", zap.Bool("reuse", s.reuseNodeID)) + if s.reuseNodeID { + // Notice, For standalone, all process share the same nodeID. 
+ if nodeID = paramtable.GetNodeID(); nodeID != 0 { + return + } + } + nodeID, err = s.getServerIDWithKey(ctx, milvuscommon.DefaultIDKey) + if err != nil { + return + } + if s.reuseNodeID { + paramtable.SetNodeID(nodeID) + } + }) + return nodeID, err +} + +func (s *etcdSession) getServerIDWithKey(ctx context.Context, key string) (int64, error) { + log := log.Ctx(ctx).With(zap.String("key", key)) + completeKey := path.Join(s.metaRoot, milvuscommon.DefaultServiceRoot, key) + for { + getResp, err := s.client.Get(ctx, completeKey) + if err != nil { + log.Warn("Session get etcd key error", zap.Error(err)) + return -1, err + } + if getResp.Count <= 0 { + log.Warn("Session there is no value") + continue + } + value := string(getResp.Kvs[0].Value) + valueInt, err := strconv.ParseInt(value, 10, 64) + if err != nil { + log.Warn("Session ParseInt error", zap.Error(err)) + continue + } + nextValue := strconv.FormatInt(valueInt+1, 10) + txnResp, err := s.client.Txn(ctx).If( + clientv3.Compare( + clientv3.Value(completeKey), + "=", + value)). 
+ Then(clientv3.OpPut(completeKey, nextValue)).Commit() + if err != nil { + log.Warn("Session Txn failed", zap.Error(err)) + return -1, err + } + + if !txnResp.Succeeded { + log.Warn("Session Txn unsuccessful") + continue + } + log.Info("Session get serverID success", zap.Int64("ServerId", valueInt), zap.String("completeKey", completeKey)) + return valueInt, nil + } +} + +func (s *etcdSession) getCompleteKey() string { + key := s.ComponentType() + if !s.exclusive || (s.enableActiveStandBy && s.isStandby.Load()) { + key = fmt.Sprintf("%s-%d", key, s.ID()) + } + return path.Join(s.metaRoot, milvuscommon.DefaultServiceRoot, key) +} + +// registerService registers the service to etcd so that other services +// can find that the service is online and issue subsequent operations +// RegisterService will save a key-value in etcd +// key: metaRootPath + "/services" + "/ServerName-ServerID" +// value: json format +// +// { +// ServerID int64 `json:"ServerID,omitempty"` +// ServerName string `json:"ServerName,omitempty"` +// Address string `json:"Address,omitempty"` +// Exclusive bool `json:"Exclusive,omitempty"` +// } +// +// Exclusive means whether this service can exist two at the same time, if so, +// it is false. Otherwise, set it to true. 
+func (s *etcdSession) registerService(ctx context.Context) (<-chan *clientv3.LeaseKeepAliveResponse, error) { + log := log.Ctx(ctx).With( + zap.String("ComponentType", s.ComponentType()), + zap.Int64("ServerID", s.ID()), + ) + + if s.enableActiveStandBy { + s.isStandby.Store(true) + } + completeKey := s.getCompleteKey() + var ch <-chan *clientv3.LeaseKeepAliveResponse + log.Debug("service begin to register to etcd", zap.String("key", completeKey)) + + ttl := s.sessionTTL + retryTimes := s.sessionRetryTimes + if !s.useCustomConfig { + ttl = paramtable.Get().CommonCfg.SessionTTL.GetAsInt64() + retryTimes = paramtable.Get().CommonCfg.SessionRetryTimes.GetAsInt64() + } + + registerFn := func() error { + resp, err := s.client.Grant(ctx, ttl) + if err != nil { + log.Error("register service", zap.Error(err)) + return err + } + s.leaseID = resp.ID + + sessionJSON, err := json.Marshal(s) + if err != nil { + return err + } + + txnResp, err := s.client.Txn(ctx).If( + clientv3.Compare( + clientv3.Version(completeKey), + "=", + 0)). 
+ Then(clientv3.OpPut(completeKey, string(sessionJSON), clientv3.WithLease(resp.ID))).Commit() + + if err != nil { + log.Warn("compare and swap error, maybe the key has already been registered", zap.Error(err)) + return err + } + + if !txnResp.Succeeded { + return fmt.Errorf("function CompareAndSwap error for compare is false for key: %s", s.ComponentType()) + } + log.Debug("put session key into etcd", zap.String("key", completeKey), zap.String("value", string(sessionJSON))) + + keepAliveCtx, keepAliveCancel := context.WithCancel(context.Background()) + s.keepaliveCancel = keepAliveCancel + + ch, err = s.client.KeepAlive(keepAliveCtx, resp.ID) + if err != nil { + log.Warn("go error during keeping alive with etcd", zap.Error(err)) + return err + } + log.Info("Service registered successfully") + return nil + } + err := retry.Do(ctx, registerFn, retry.Attempts(uint(retryTimes))) + if err != nil { + return nil, err + } + return ch, nil +} + +// keepAlive processes the response of etcd keepAlive interface +// If keepAlive fails for unexpected error, it will send a signal to the channel. 
+func (s *etcdSession) keepAlive(ch <-chan *clientv3.LeaseKeepAliveResponse) (failChannel <-chan struct{}) { + failCh := make(chan struct{}) + go func() { + for { + select { + case <-s.closeCh: + log.Info("etcd session quit") + return + case resp, ok := <-ch: + if !ok { + log.Warn("session keepalive channel closed") + close(failCh) + return + } + if resp == nil { + log.Warn("session keepalive response failed") + close(failCh) + return + } + } + } + }() + return failCh +} diff --git a/internal/registry/etcd/session_keepalive.go b/internal/registry/etcd/session_keepalive.go new file mode 100644 index 0000000000000..2db036757a187 --- /dev/null +++ b/internal/registry/etcd/session_keepalive.go @@ -0,0 +1 @@ +package etcd diff --git a/internal/registry/etcd/watcher.go b/internal/registry/etcd/watcher.go new file mode 100644 index 0000000000000..9e84d47be4102 --- /dev/null +++ b/internal/registry/etcd/watcher.go @@ -0,0 +1,164 @@ +package etcd + +import ( + "context" + "encoding/json" + "path" + "sync" + + "github.com/milvus-io/milvus/internal/registry" + milvuscommon "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "go.etcd.io/etcd/api/v3/mvccpb" + v3rpc "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" +) + +// Rewatch defines the behavior outer session watch handles ErrCompacted +// it should process the current full list of session +// and returns err if meta error or anything else goes wrong +type Rewatch func(sessions map[string]registry.ServiceEntry) error +type watcher[T any] struct { + s *etcdServiceDiscovery + rch clientv3.WatchChan + eventCh chan registry.SessionEvent[T] + closeCh chan struct{} + prefix string + rewatch Rewatch + validate func(registry.ServiceEntry) bool + convert func(context.Context, registry.ServiceEntry) (T, error) + stopOnce sync.Once +} + +func newWatcher[T any](s *etcdServiceDiscovery, prefix string, etcdCh clientv3.WatchChan, convert 
func(context.Context, registry.ServiceEntry) (T, error)) *watcher[T] { + w := &watcher[T]{ + s: s, + eventCh: make(chan registry.SessionEvent[T], 100), + rch: etcdCh, + prefix: prefix, + // rewatch: rewatch, + validate: func(s registry.ServiceEntry) bool { return true }, + convert: convert, + } + w.start() + return w +} + +// Watch returns the internal event channel, implementing `registry.ServiceWatcher` +func (w *watcher[T]) Watch() <-chan registry.SessionEvent[T] { + return w.eventCh +} + +func (w *watcher[T]) Stop() { + w.stopOnce.Do(func() { + close(w.closeCh) + }) +} + +func (w *watcher[T]) start() { + go func() { + for { + select { + case <-w.closeCh: + return + case wresp, ok := <-w.rch: + if !ok { + log.Warn("session watch channel closed") + return + } + w.handleWatchResponse(wresp) + } + } + }() +} + +func (w *watcher[T]) handleWatchResponse(wresp clientv3.WatchResponse) { + if wresp.Err() != nil { + err := w.handleWatchErr(wresp.Err()) + if err != nil { + log.Error("failed to handle watch session response", zap.Error(err)) + panic(err) + } + return + } + for _, ev := range wresp.Events { + session := &etcdSession{} + var eventType registry.SessionEventType + switch ev.Type { + case mvccpb.PUT: + log.Debug("watch services", + zap.Any("add kv", ev.Kv)) + err := json.Unmarshal(ev.Kv.Value, session) + if err != nil { + log.Error("watch services", zap.Error(err)) + continue + } + if !w.validate(session) { + continue + } + if session.stopping.Load() { + eventType = registry.SessionUpdateEvent + } else { + eventType = registry.SessionAddEvent + } + case mvccpb.DELETE: + log.Debug("watch services", + zap.Any("delete kv", ev.PrevKv)) + err := json.Unmarshal(ev.PrevKv.Value, session) + if err != nil { + log.Error("watch services", zap.Error(err)) + continue + } + if !w.validate(session) { + continue + } + eventType = registry.SessionDelEvent + } + log.Debug("WatchService", zap.Any("event type", eventType)) + client, err := w.convert(context.TODO(), session) + if 
err != nil { + log.Warn("failed to convert session entry to client", zap.Error(err)) + continue + } + w.eventCh <- registry.SessionEvent[T]{ + EventType: eventType, + Entry: session, + Client: client, + } + } +} + +func (w *watcher[T]) handleWatchErr(err error) error { + ctx := context.TODO() + // if not ErrCompacted, just close the channel + if err != v3rpc.ErrCompacted { + //close event channel + log.Warn("Watch service found error", zap.Error(err)) + close(w.eventCh) + return err + } + + _, revision, err := w.s.getServices(ctx, w.prefix) + if err != nil { + log.Warn("GetSession before rewatch failed", zap.String("prefix", w.prefix), zap.Error(err)) + close(w.eventCh) + return err + } + // rewatch is nil, no logic to handle + if w.rewatch == nil { + log.Warn("Watch service with ErrCompacted but no rewatch logic provided") + } else { + //TODO + /* + err = w.rewatch(sessions)*/ + } + if err != nil { + log.Warn("WatchServices rewatch failed", zap.String("prefix", w.prefix), zap.Error(err)) + close(w.eventCh) + return err + } + + w.rch = w.s.client.Watch(ctx, path.Join(w.s.metaRoot, milvuscommon.DefaultServiceRoot, w.prefix), clientv3.WithPrefix(), clientv3.WithPrevKV(), clientv3.WithRev(revision)) + return nil +} diff --git a/internal/registry/inproc/gen/gen.go b/internal/registry/inproc/gen/gen.go new file mode 100644 index 0000000000000..9394ccbfe5587 --- /dev/null +++ b/internal/registry/inproc/gen/gen.go @@ -0,0 +1,54 @@ +// This program generates internal/registry/inproc/wrapper_gen.go. 
Invoked by go generate +package main + +import ( + "fmt" + "html/template" + "os" +) + +var wrapperTpl = template.Must(template.New("").Delims("[[", "]]").Parse(`// Code generated by go generate; DO NOT EDIT +// This file is generated by go generate +package inproc + +import ( + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) +[[range .Services ]][[with .]] +type w[[.]] struct { + types.[[.]] +} + +func (w *w[[.]]) Set(v types.[[.]]) { + w.[[.]] = v +} + +func wrap[[.]](s *inProcServiceDiscovery, v types.[[.]]) types.[[.]] { + s.serviceMut.Lock() + defer s.serviceMut.Unlock() + w := &w[[.]]{[[.]]: v} + + s.wrappers[typeutil.[[.]]Role] = append(s.wrappers[typeutil.[[.]]Role], w) + return w +} +[[end]][[end]] +`)) + +func main() { + f, err := os.OpenFile("wrapper_gen.go", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0755) + if err != nil { + fmt.Println(err.Error()) + return + } + defer f.Close() + params := map[string]any{ + "Services": []string{ + "RootCoord", + "DataCoord", + "QueryCoord", + }, + } + + wrapperTpl.Execute(f, params) +} diff --git a/internal/registry/inproc/service_discovery.go b/internal/registry/inproc/service_discovery.go new file mode 100644 index 0000000000000..c231c2e0cde60 --- /dev/null +++ b/internal/registry/inproc/service_discovery.go @@ -0,0 +1,197 @@ +package inproc + +import ( + "context" + "sync" + + "github.com/milvus-io/milvus/internal/registry" + "github.com/milvus-io/milvus/internal/registry/options" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type inProcServiceDiscovery struct { + inner registry.ServiceDiscovery + + serviceMut sync.RWMutex + // component type to service component instance + // all components shall be singleton in one process + services map[string]types.Component + // coordinator wrappers + wrappers map[string][]types.Component +} + +var _ registry.ServiceDiscovery = &inProcServiceDiscovery{} + +// 
NewInProcServiceDiscovery wraps other ServiceDiscovery component with extra inner process discovery logic. +func NewInProcServiceDiscovery(inner registry.ServiceDiscovery) registry.ServiceDiscovery { + return &inProcServiceDiscovery{ + inner: inner, + } +} + +func (s *inProcServiceDiscovery) GetServices(ctx context.Context, component string) ([]registry.ServiceEntry, error) { + return s.inner.GetServices(ctx, component) +} + +// WatchServices uses inner implmentation. +func (s *inProcServiceDiscovery) WatchServices(ctx context.Context, component string, opts ...options.WatchOption) ([]registry.ServiceEntry, registry.ServiceWatcher[registry.ServiceEntry], error) { + panic("not implemented") +} + +/* + Since the initialization sequence cannot guarantee that all coordinator could get other coord after they are ready, + inProc package shall wrap the coord client and the inproc alternative +*/ + +// GetRootCoord returns RootCoord grpc client as types.RootCoord. +func (s *inProcServiceDiscovery) GetRootCoord(ctx context.Context) (types.RootCoord, error) { + rc, ok := getComponent[types.RootCoord](s, typeutil.RootCoordRole) + if ok { + return rc, nil + } + + client, err := s.inner.GetRootCoord(ctx) + if err != nil { + return nil, err + } + return wrapRootCoord(s, client), nil +} + +// GetQueryCoord returns QueryCoord grpc client as types.QueryCoord. +func (s *inProcServiceDiscovery) GetQueryCoord(ctx context.Context) (types.QueryCoord, error) { + qc, ok := getComponent[types.QueryCoord](s, typeutil.QueryCoordRole) + if ok { + return qc, nil + } + + client, err := s.inner.GetQueryCoord(ctx) + if err != nil { + return nil, err + } + return wrapQueryCoord(s, client), nil +} + +// GetDataCoord returns DataCoord grpc client as types.DataCoord. 
+func (s *inProcServiceDiscovery) GetDataCoord(ctx context.Context) (types.DataCoord, error) { + qc, ok := getComponent[types.DataCoord](s, typeutil.DataCoordRole) + if ok { + return qc, nil + } + + client, err := s.inner.GetDataCoord(ctx) + if err != nil { + return nil, err + } + return wrapDataCoord(s, client), nil +} + +/* + Worker nodes initialized after corrdinators, so watchers shall listen to inproc.Regsiter events. +*/ +// WatchDataNode returns current DataNode instances and chan to watch. +func (s *inProcServiceDiscovery) WatchDataNode(ctx context.Context, opts ...options.WatchOption) ([]types.DataNode, registry.ServiceWatcher[types.DataNode], error) { + current, watcher, err := s.inner.WatchDataNode(ctx, opts...) + if err != nil { + return nil, nil, err + } + return current, newWatcher(s, watcher, typeutil.DataNodeRole), nil +} + +// WatchQueryNode returns current QueryNode instance and chan to watch. +func (s *inProcServiceDiscovery) WatchQueryNode(ctx context.Context, opts ...options.WatchOption) ([]types.QueryNode, registry.ServiceWatcher[types.QueryNode], error) { + current, watcher, err := s.inner.WatchQueryNode(ctx, opts...) + if err != nil { + return nil, nil, err + } + return current, newWatcher(s, watcher, typeutil.QueryNodeRole), nil +} + +func (s *inProcServiceDiscovery) WatchIndexNode(ctx context.Context, opts ...options.WatchOption) ([]types.IndexNode, registry.ServiceWatcher[types.IndexNode], error) { + current, watcher, err := s.inner.WatchIndexNode(ctx, opts...) + if err != nil { + return nil, nil, err + } + return current, newWatcher(s, watcher, typeutil.IndexNodeRole), nil +} + +func (s *inProcServiceDiscovery) WatchProxy(ctx context.Context, opts ...options.WatchOption) ([]types.Proxy, registry.ServiceWatcher[types.Proxy], error) { + current, watcher, err := s.inner.WatchProxy(ctx, opts...) 
// RegisterRootCoord registers rootcoord at addr by delegating to
// registerCoordinator with the inner discovery's RegisterRootCoord as the
// register function and *wRootCoord as the wrapper type.
// It returns the session and error produced by registerCoordinator.
func (s *inProcServiceDiscovery) RegisterRootCoord(ctx context.Context, rootcoord types.RootCoord, addr string, opts ...options.RegisterOption) (registry.Session, error) {
	return registerCoordinator[types.RootCoord, *wRootCoord](ctx, s, typeutil.RootCoordRole, rootcoord, addr, s.inner.RegisterRootCoord, opts...)
}
+} + +func (s *inProcServiceDiscovery) RegisterQueryNode(ctx context.Context, querynode types.QueryNode, addr string, opts ...options.RegisterOption) (registry.Session, error) { + return registerNode(ctx, s, typeutil.QueryNodeRole, querynode, addr, s.inner.RegisterQueryNode, opts...) +} + +func getComponent[T types.Component](s *inProcServiceDiscovery, component string) (T, bool) { + s.serviceMut.RLock() + defer s.serviceMut.RUnlock() + var result T + comp, ok := s.services[component] + if !ok { + return result, false + } + + return comp.(T), true +} + +func registerCoordinator[T types.Component, W interface{ Set(T) }](ctx context.Context, + s *inProcServiceDiscovery, component string, service T, addr string, + register func(context.Context, T, string, ...options.RegisterOption) (registry.Session, error), opts ...options.RegisterOption) (registry.Session, error) { + s.serviceMut.Lock() + defer s.serviceMut.Unlock() + session, err := register(ctx, service, addr, opts...) + if err != nil { + return nil, err + } + + s.services[component] = service + + for _, wrapper := range s.wrappers[component] { + w := wrapper.(W) + w.Set(service) + } + return session, err +} + +func registerNode[T types.Component](ctx context.Context, + s *inProcServiceDiscovery, component string, service T, addr string, + register func(context.Context, T, string, ...options.RegisterOption) (registry.Session, error), opts ...options.RegisterOption) (registry.Session, error) { + s.serviceMut.Lock() + defer s.serviceMut.Unlock() + session, err := register(ctx, service, addr, opts...) 
+ if err != nil { + return nil, err + } + + s.services[component] = service + return session, err +} diff --git a/internal/registry/inproc/watcher.go b/internal/registry/inproc/watcher.go new file mode 100644 index 0000000000000..1ea004148f5f1 --- /dev/null +++ b/internal/registry/inproc/watcher.go @@ -0,0 +1,69 @@ +package inproc + +import ( + "github.com/milvus-io/milvus/internal/registry" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/log" +) + +type watcher[T types.Component] struct { + inner registry.ServiceWatcher[T] + + evtCh chan registry.SessionEvent[T] + closeCh chan struct{} + component string + s *inProcServiceDiscovery +} + +func newWatcher[T types.Component](s *inProcServiceDiscovery, inner registry.ServiceWatcher[T], component string) *watcher[T] { + w := &watcher[T]{ + inner: inner, + component: component, + s: s, + closeCh: make(chan struct{}), + evtCh: make(chan registry.SessionEvent[T], 100), + } + w.start() + return w +} + +func (w *watcher[T]) Watch() <-chan registry.SessionEvent[T] { + return w.inner.Watch() +} + +func (w *watcher[T]) Stop() { + w.inner.Stop() +} + +func (w *watcher[T]) start() { + go w.work() +} + +func (w *watcher[T]) work() { + for { + select { + case <-w.closeCh: + case evt, ok := <-w.inner.Watch(): + if !ok { + log.Warn("inproc watcher inner channel closed") + return + } + w.processEvt(evt) + } + } +} + +func (w *watcher[T]) processEvt(evt registry.SessionEvent[T]) { + // only apply replace client when session type is `Add` + if evt.EventType != registry.SessionAddEvent { + w.evtCh <- evt + return + } + service, ok := getComponent[T](w.s, w.component) + if ok { + //TODO add address check + evt.Client = service + } + + w.evtCh <- evt +} diff --git a/internal/registry/inproc/wrapper.go b/internal/registry/inproc/wrapper.go new file mode 100644 index 0000000000000..cc1150a7755c0 --- /dev/null +++ b/internal/registry/inproc/wrapper.go @@ -0,0 +1,14 @@ +package inproc + +//go:generate go run 
// The intended implementation is the following one, but Go cannot embed
// a type parameter T itself as an anonymous field.
// SessionOpt collects the switches that are applied when a service session is
// registered.
type SessionOpt struct {
	StandBy     bool
	Exclusive   bool
	TriggerKill bool
}

// DefaultSessionOpt returns the baseline options: TriggerKill enabled, StandBy
// and Exclusive disabled.
func DefaultSessionOpt() SessionOpt {
	var opt SessionOpt
	opt.TriggerKill = true
	return opt
}

// RegisterOption mutates a SessionOpt; it is passed to the register calls to
// adjust session options.
type RegisterOption func(*SessionOpt)

// WithStandBy returns an option that switches the stand-by flag on.
func WithStandBy() RegisterOption {
	enable := func(target *SessionOpt) {
		target.StandBy = true
	}
	return enable
}

// WithTriggerKill returns an option that sets the TriggerKill flag to v.
func WithTriggerKill(v bool) RegisterOption {
	set := func(target *SessionOpt) {
		target.TriggerKill = v
	}
	return set
}

// watchOpt holds watch settings; it has no fields yet.
type watchOpt struct{}

// WatchOption mutates a watchOpt; it is passed to the watch calls.
type WatchOption func(*watchOpt)
// Session is a ServiceEntry extended with lifecycle operations; it is the value
// returned by the Register* methods of ServiceDiscovery.
// NOTE(review): the exact semantics and required call order of Init, Register,
// Revoke and Stop are defined by the implementations and are not visible here —
// confirm against the etcd session implementation.
type Session interface {
	// ServiceEntry exposes the ID, address and component type of the session.
	ServiceEntry
	Revoke(ctx context.Context) error
	Stop(ctx context.Context) error
	Init(ctx context.Context) error
	Register(ctx context.Context) error
}
type ServiceDiscovery interface {
	// General get & watch; these work on plain service entries only and take
	// the component name as a parameter.
	GetServices(ctx context.Context, component string) ([]ServiceEntry, error)
	WatchServices(ctx context.Context, component string, opts ...options.WatchOption) ([]ServiceEntry, ServiceWatcher[ServiceEntry], error)
	// Coordinators: each getter returns a single coordinator instance.
	GetRootCoord(ctx context.Context) (types.RootCoord, error)
	GetQueryCoord(ctx context.Context) (types.QueryCoord, error)
	GetDataCoord(ctx context.Context) (types.DataCoord, error)
	// Watch methods: each returns a slice of node instances together with a
	// watcher delivering subsequent session events.
	WatchDataNode(ctx context.Context, opts ...options.WatchOption) ([]types.DataNode, ServiceWatcher[types.DataNode], error)
	WatchQueryNode(ctx context.Context, opts ...options.WatchOption) ([]types.QueryNode, ServiceWatcher[types.QueryNode], error)
	WatchIndexNode(ctx context.Context, opts ...options.WatchOption) ([]types.IndexNode, ServiceWatcher[types.IndexNode], error)
	WatchProxy(ctx context.Context, opts ...options.WatchOption) ([]types.Proxy, ServiceWatcher[types.Proxy], error)
	// Register methods: each takes the component instance and its address and
	// returns the Session created for it.
	RegisterRootCoord(ctx context.Context, rootcoord types.RootCoord, addr string, opts ...options.RegisterOption) (Session, error)
	RegisterDataCoord(ctx context.Context, datacoord types.DataCoord, addr string, opts ...options.RegisterOption) (Session, error)
	RegisterQueryCoord(ctx context.Context, querycoord types.QueryCoord, addr string, opts ...options.RegisterOption) (Session, error)
	RegisterProxy(ctx context.Context, proxy types.Proxy, addr string, opts ...options.RegisterOption) (Session, error)
	RegisterDataNode(ctx context.Context, datanode types.DataNode, addr string, opts ...options.RegisterOption) (Session, error)
	RegisterIndexNode(ctx context.Context, indexnode types.IndexNode, addr string, opts ...options.RegisterOption) (Session, error)
	// NOTE(review): the parameter is named `index` although it is a QueryNode;
	// looks like a copy-paste leftover — consider renaming to `querynode`.
	RegisterQueryNode(ctx context.Context, index types.QueryNode, addr string, opts ...options.RegisterOption) (Session, error)
}
// SessionEventType session event type
type SessionEventType int

const (
	// SessionNoneEvent place holder for zero value
	SessionNoneEvent SessionEventType = iota
	// SessionAddEvent event type for a new Session Added
	SessionAddEvent
	// SessionDelEvent event type for a Session deleted
	SessionDelEvent
	// SessionUpdateEvent event type for a Session stopping
	SessionUpdateEvent
)

// String implements fmt.Stringer. Every declared constant, including the
// zero-value placeholder SessionNoneEvent, maps to its own name so that logged
// event types are never blank; values outside the declared set still yield "".
func (t SessionEventType) String() string {
	switch t {
	case SessionNoneEvent:
		return "SessionNoneEvent"
	case SessionAddEvent:
		return "SessionAddEvent"
	case SessionDelEvent:
		return "SessionDelEvent"
	case SessionUpdateEvent:
		return "SessionUpdateEvent"
	default:
		return ""
	}
}
// GetRequestMetadata returns the per-request metadata attached by this
// credential. It currently returns an empty map, so t.Value is not sent.
// NOTE(review): the previous implementation sent t.Value under
// util.HeaderSourceID; that line and the util import are only commented out,
// which looks like a temporary change — confirm that dropping the source-ID
// header is intended before merging.
func (t *Token) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	// return map[string]string{util.HeaderSourceID: t.Value}, nil
	return map[string]string{}, nil
}
*milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) }] struct { - getAddrFunc func() (string, error) + getAddrFunc func(context.Context) (string, error) newGrpcClient func(cc *grpc.ClientConn) T grpcClient T @@ -144,7 +148,7 @@ func (c *ClientBase[T]) GetAddr() string { } // SetGetAddrFunc sets getAddrFunc of client -func (c *ClientBase[T]) SetGetAddrFunc(f func() (string, error)) { +func (c *ClientBase[T]) SetGetAddrFunc(f func(context.Context) (string, error)) { c.getAddrFunc = f } @@ -207,7 +211,7 @@ func (c *ClientBase[T]) resetConnection(client T) { } func (c *ClientBase[T]) connect(ctx context.Context) error { - addr, err := c.getAddrFunc() + addr, err := c.getAddrFunc(ctx) if err != nil { log.Ctx(ctx).Warn("failed to get client address", zap.Error(err)) return err @@ -519,11 +523,6 @@ func (c *ClientBase[T]) GetNodeID() int64 { return c.NodeID.Load() } -// SetSession set session role of client -func (c *ClientBase[T]) SetSession(sess *sessionutil.Session) { - c.sess = sess -} - func IsCrossClusterRoutingErr(err error) bool { // GRPC utilizes `status.Status` to encapsulate errors, // hence it is not viable to employ the `errors.Is` for assessment. 
diff --git a/internal/util/grpcclient/client_test.go b/internal/util/grpcclient/client_test.go index 2a8ce50550e02..7a06947841ddf 100644 --- a/internal/util/grpcclient/client_test.go +++ b/internal/util/grpcclient/client_test.go @@ -68,7 +68,7 @@ func TestClientBase_GetRole(t *testing.T) { func TestClientBase_connect(t *testing.T) { t.Run("failed to connect", func(t *testing.T) { base := ClientBase[*mockClient]{ - getAddrFunc: func() (string, error) { + getAddrFunc: func(_ context.Context) (string, error) { return "", nil }, DialTimeout: time.Millisecond, @@ -81,7 +81,7 @@ func TestClientBase_connect(t *testing.T) { t.Run("failed to get addr", func(t *testing.T) { errMock := errors.New("mocked") base := ClientBase[*mockClient]{ - getAddrFunc: func() (string, error) { + getAddrFunc: func(_ context.Context) (string, error) { return "", errMock }, DialTimeout: time.Millisecond, @@ -231,7 +231,7 @@ func testCall(t *testing.T, compressed bool) { base.grpcClientMtx.Lock() base.grpcClient = nil base.grpcClientMtx.Unlock() - base.SetGetAddrFunc(func() (string, error) { return "", nil }) + base.SetGetAddrFunc(func(context.Context) (string, error) { return "", nil }) t.Run("Call with connect failure", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -296,7 +296,7 @@ func TestClientBase_Recall(t *testing.T) { base.grpcClientMtx.Lock() base.grpcClient = nil base.grpcClientMtx.Unlock() - base.SetGetAddrFunc(func() (string, error) { return "", nil }) + base.SetGetAddrFunc(func(_ context.Context) (string, error) { return "", nil }) t.Run("ReCall with connect failure", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -367,7 +367,7 @@ func TestClientBase_RetryPolicy(t *testing.T) { MaxBackoff: 60.0, } clientBase.SetRole(typeutil.DataCoordRole) - clientBase.SetGetAddrFunc(func() (string, error) { + clientBase.SetGetAddrFunc(func(_ context.Context) (string, error) { return address.String(), nil }) 
clientBase.SetNewGrpcClientFunc(func(cc *grpc.ClientConn) rootcoordpb.RootCoordClient { @@ -430,7 +430,7 @@ func TestClientBase_Compression(t *testing.T) { CompressionEnabled: true, } clientBase.SetRole(typeutil.DataCoordRole) - clientBase.SetGetAddrFunc(func() (string, error) { + clientBase.SetGetAddrFunc(func(_ context.Context) (string, error) { return address.String(), nil }) clientBase.SetNewGrpcClientFunc(func(cc *grpc.ClientConn) rootcoordpb.RootCoordClient { diff --git a/internal/util/grpcclient/raw_provider.go b/internal/util/grpcclient/raw_provider.go new file mode 100644 index 0000000000000..a2ff124578d1b --- /dev/null +++ b/internal/util/grpcclient/raw_provider.go @@ -0,0 +1,71 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package grpcclient + +import ( + "context" + "encoding/json" + "path" + + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" + clientv3 "go.etcd.io/etcd/client/v3" +) + +type rawEntryProvider struct { + client *clientv3.Client + metaPath string + role string +} + +// NewRawEntryProvider returns a RawEntryProvider for backward-compatibility. +// works for coordinator service entries. 
+// Shall be removed when all service-discovery logic is unified. +func NewRawEntryProvider(cli *clientv3.Client, metaPath string, role string) ServiceProvider { + return &rawEntryProvider{ + client: cli, + metaPath: metaPath, + role: role, + } +} + +// GetServiceEntry returns current service entry information. +func (p *rawEntryProvider) GetServiceEntry(ctx context.Context) (string, int64, error) { + key := path.Join(p.metaPath, common.DefaultServiceRoot, p.role) + + resp, err := p.client.Get(ctx, key) + if err != nil { + err := merr.WrapErrIoFailed(key, err.Error()) + return "", 0, err + } + + if len(resp.Kvs) != 1 { + err := merr.WrapErrIoKeyNotFound(key) + return "", 0, err + } + + kv := resp.Kvs[0] + session := sessionutil.Session{} + err = json.Unmarshal(kv.Value, &session) + if err != nil { + err := merr.WrapErrIoKeyNotFound(key, err.Error()) + return "", 0, err + } + + return session.Address, session.ServerID, nil +} diff --git a/internal/util/mock/grpcclient.go b/internal/util/mock/grpcclient.go index b466f097c3759..7bf51bb8c9a06 100644 --- a/internal/util/mock/grpcclient.go +++ b/internal/util/mock/grpcclient.go @@ -33,7 +33,7 @@ import ( ) type GRPCClientBase[T any] struct { - getAddrFunc func() (string, error) + getAddrFunc func(context.Context) (string, error) newGrpcClient func(cc *grpc.ClientConn) T grpcClient T @@ -45,7 +45,7 @@ type GRPCClientBase[T any] struct { sess *sessionutil.Session } -func (c *GRPCClientBase[T]) SetGetAddrFunc(f func() (string, error)) { +func (c *GRPCClientBase[T]) SetGetAddrFunc(f func(context.Context) (string, error)) { c.getAddrFunc = f } @@ -164,7 +164,3 @@ func (c *GRPCClientBase[T]) GetNodeID() int64 { func (c *GRPCClientBase[T]) SetNodeID(nodeID int64) { c.nodeID = nodeID } - -func (c *GRPCClientBase[T]) SetSession(sess *sessionutil.Session) { - c.sess = sess -} diff --git a/pkg/common/common.go b/pkg/common/common.go index 1ec3bfe3bd51d..c5beea97a5bbc 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ 
-81,6 +81,14 @@ const ( SegmentIndexPath = `index_files` ) +// Session +const ( + // DefaultServiceRoot default root path used in kv by Session + DefaultServiceRoot = "session/" + // DefaultIDKey default id key for Session + DefaultIDKey = "id" +) + // Search, Index parameter keys const ( TopKKey = "topk"