diff --git a/cmd/tools/config/generate.go b/cmd/tools/config/generate.go index f09925c573206..dc93e68b379eb 100644 --- a/cmd/tools/config/generate.go +++ b/cmd/tools/config/generate.go @@ -298,11 +298,11 @@ func WriteYaml(w io.Writer) { }, { name: "tls", - header: "\n# Configure the proxy tls enable.", + header: "\n# Configure external tls.", }, { name: "internaltls", - header: "\n# Configure the node-tls enable.", + header: "\n# Configure internal tls.", }, { name: "common", diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 7ff8b898ae1aa..6cbb6a64c3fdd 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -778,17 +778,18 @@ grpc: maxCancelError: 32 minSessionCheckInterval: 200 -# Configure the proxy tls enable. +# Configure external tls. tls: serverPemPath: configs/cert/server.pem serverKeyPath: configs/cert/server.key caPemPath: configs/cert/ca.pem -# Configure the node-tls enable. +# Configure internal tls. internaltls: serverPemPath: configs/cert/server.pem serverKeyPath: configs/cert/server.key caPemPath: configs/cert/ca.pem + sni: localhost # The server name indication (SNI) for internal TLS, should be the same as the name provided by the certificates ref: https://en.wikipedia.org/wiki/Server_Name_Indication common: defaultPartitionName: _default # Name of the default partition when a collection is created diff --git a/internal/distributed/datacoord/client/client.go b/internal/distributed/datacoord/client/client.go index bb095cdae30b0..af25c24ae02d7 100644 --- a/internal/distributed/datacoord/client/client.go +++ b/internal/distributed/datacoord/client/client.go @@ -80,6 +80,7 @@ func NewClient(ctx context.Context) (types.DataCoordClient, error) { return nil, err } client.grpcClient.SetInternalTLSCertPool(cp) + client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue()) } return client, nil } diff --git a/internal/distributed/datanode/client/client.go b/internal/distributed/datanode/client/client.go index 5859e7ee33e5c..a189266781844 100644 --- a/internal/distributed/datanode/client/client.go +++ b/internal/distributed/datanode/client/client.go @@ -81,6 +81,7 @@ func NewClient(ctx context.Context, addr string, serverID int64) (types.DataNode return nil, err } client.grpcClient.SetInternalTLSCertPool(cp) + client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue()) } return client, nil } diff --git a/internal/distributed/indexnode/client/client.go b/internal/distributed/indexnode/client/client.go index 7387bdb1385e3..cd4d29780ba05 100644 --- a/internal/distributed/indexnode/client/client.go +++ b/internal/distributed/indexnode/client/client.go @@ -81,6 +81,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64, encryption bool) return nil, err } client.grpcClient.SetInternalTLSCertPool(cp) + client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue()) } return client, nil } diff --git a/internal/distributed/proxy/client/client.go b/internal/distributed/proxy/client/client.go index ffbc91ab20ca3..5c5cbd50838c8 100644 --- a/internal/distributed/proxy/client/client.go +++ b/internal/distributed/proxy/client/client.go @@ -78,6 +78,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.ProxyClien return nil, err } client.grpcClient.SetInternalTLSCertPool(cp) + client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue()) } return client, nil } diff --git a/internal/distributed/querycoord/client/client.go b/internal/distributed/querycoord/client/client.go index 867d73e7f6a87..db0b61a901250 100644 --- a/internal/distributed/querycoord/client/client.go +++ b/internal/distributed/querycoord/client/client.go @@ -72,6 +72,7 @@ func NewClient(ctx context.Context) (types.QueryCoordClient, error) { return nil, err } client.grpcClient.SetInternalTLSCertPool(cp) + client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue()) } return client, nil } diff --git a/internal/distributed/querynode/client/client.go b/internal/distributed/querynode/client/client.go index 7dfe4dc8be62a..1ee1161d5afd9 100644 --- a/internal/distributed/querynode/client/client.go +++ b/internal/distributed/querynode/client/client.go @@ -81,6 +81,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.QueryNodeC return nil, err } client.grpcClient.SetInternalTLSCertPool(cp) + client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue()) } return client, nil } diff --git a/internal/distributed/rootcoord/client/client.go b/internal/distributed/rootcoord/client/client.go index 6d0c871042366..98da081920106 100644 --- a/internal/distributed/rootcoord/client/client.go +++ b/internal/distributed/rootcoord/client/client.go @@ -79,6 +79,7 @@ func NewClient(ctx context.Context) (types.RootCoordClient, error) { return nil, err } client.grpcClient.SetInternalTLSCertPool(cp) + client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue()) } return client, nil } diff --git a/internal/mocks/mock_grpc_client.go b/internal/mocks/mock_grpc_client.go index 4bd3c96f7de2f..4d722f96ef1fa 100644 --- a/internal/mocks/mock_grpc_client.go +++ b/internal/mocks/mock_grpc_client.go @@ -360,6 +360,40 @@ func (_c *MockGrpcClient_SetInternalTLSCertPool_Call[T]) RunAndReturn(run func(* return _c } + +// SetInternalTLSServerName provides a mock function with given fields: cp +func (_m *MockGrpcClient[T]) SetInternalTLSServerName(cp string) { + _m.Called(cp) +} + +// MockGrpcClient_SetInternalTLSServerName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetInternalTLSServerName' +type MockGrpcClient_SetInternalTLSServerName_Call[T grpcclient.GrpcComponent] struct { + *mock.Call +} + +// SetInternalTLSServerName is a helper method to define mock.On call +// - cp string +func (_e *MockGrpcClient_Expecter[T]) SetInternalTLSServerName(cp interface{}) *MockGrpcClient_SetInternalTLSServerName_Call[T] { + return &MockGrpcClient_SetInternalTLSServerName_Call[T]{Call: _e.mock.On("SetInternalTLSServerName", cp)} +} + +func (_c *MockGrpcClient_SetInternalTLSServerName_Call[T]) Run(run func(cp string)) *MockGrpcClient_SetInternalTLSServerName_Call[T] { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockGrpcClient_SetInternalTLSServerName_Call[T]) Return() *MockGrpcClient_SetInternalTLSServerName_Call[T] { + _c.Call.Return() + return _c +} + +func (_c *MockGrpcClient_SetInternalTLSServerName_Call[T]) RunAndReturn(run func(string)) *MockGrpcClient_SetInternalTLSServerName_Call[T] { + _c.Call.Return(run) + return _c +} + // SetNewGrpcClientFunc provides a mock function with given fields: _a0 func (_m *MockGrpcClient[T]) SetNewGrpcClientFunc(_a0 func(*grpc.ClientConn) T) { _m.Called(_a0) diff --git a/internal/util/grpcclient/client.go b/internal/util/grpcclient/client.go index 8927b44d207b1..57a0209b5bed6 100644 --- a/internal/util/grpcclient/client.go +++ b/internal/util/grpcclient/client.go @@ -86,6 +86,7 @@ type GrpcClient[T GrpcComponent] interface { SetGetAddrFunc(func() (string, error)) EnableEncryption() SetInternalTLSCertPool(cp *x509.CertPool) + SetInternalTLSServerName(cp string) SetNewGrpcClientFunc(func(cc *grpc.ClientConn) T) ReCall(ctx context.Context, caller func(client T) (any, error)) (any, error) Call(ctx context.Context, caller func(client T) (any, error)) (any, error) @@ -103,10 +104,12 @@ type ClientBase[T interface { newGrpcClient func(cc *grpc.ClientConn) T // grpcClient T - grpcClient *clientConnWrapper[T] - encryption bool - cpInternalTLS *x509.CertPool - addr atomic.String + grpcClient *clientConnWrapper[T] + encryption bool + cpInternalTLS *x509.CertPool + addr atomic.String + internalTLSServerName string + // conn *grpc.ClientConn grpcClientMtx sync.RWMutex role string @@ -194,6 +197,10 @@ func (c *ClientBase[T]) SetInternalTLSCertPool(cp *x509.CertPool) { c.cpInternalTLS = cp } +func (c *ClientBase[T]) SetInternalTLSServerName(cp string) { + c.internalTLSServerName = cp +} + // SetNewGrpcClientFunc sets newGrpcClient of client func (c *ClientBase[T]) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) T) { c.newGrpcClient = f @@ -269,7 +276,12 @@ func (c *ClientBase[T]) connect(ctx context.Context) error { dialContext, addr, // #nosec G402 - grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{RootCAs: c.cpInternalTLS})), + grpc.WithTransportCredentials(credentials.NewTLS( + &tls.Config{ + RootCAs: c.cpInternalTLS, + ServerName: c.internalTLSServerName, + }, + )), grpc.WithBlock(), grpc.WithDefaultCallOptions( grpc.MaxCallRecvMsgSize(c.ClientMaxRecvSize), diff --git a/internal/util/mock/grpcclient.go b/internal/util/mock/grpcclient.go index 4da362c8815af..448f014f15aa8 100644 --- a/internal/util/mock/grpcclient.go +++ b/internal/util/mock/grpcclient.go @@ -39,6 +39,7 @@ type GRPCClientBase[T any] struct { grpcClient T cpInternalTLS *x509.CertPool + cpInternalSNI string conn *grpc.ClientConn grpcClientMtx sync.RWMutex GetGrpcClientErr error @@ -66,6 +67,10 @@ func (c *GRPCClientBase[T]) SetInternalTLSCertPool(cp *x509.CertPool) { c.cpInternalTLS = cp } +func (c *GRPCClientBase[T]) SetInternalTLSServerName(cp string) { + c.cpInternalSNI = cp +} + func (c *GRPCClientBase[T]) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) T) { c.newGrpcClient = f } diff --git a/pkg/util/paramtable/grpc_param.go b/pkg/util/paramtable/grpc_param.go index 53e394c07ccb8..b5b271b1b18a0 100644 --- a/pkg/util/paramtable/grpc_param.go +++ b/pkg/util/paramtable/grpc_param.go @@ -541,12 +541,13 @@ type InternalTLSConfig struct { InternalTLSServerPemPath ParamItem `refreshable:"false"` InternalTLSServerKeyPath ParamItem `refreshable:"false"` InternalTLSCaPemPath ParamItem `refreshable:"false"` + InternalTLSSNI ParamItem `refreshable:"false"` } func (p *InternalTLSConfig) Init(base *BaseTable) { p.InternalTLSEnabled = ParamItem{ Key: "common.security.internaltlsEnabled", - Version: "2.0.0", + Version: "2.5.0", DefaultValue: "false", Export: true, } @@ -554,22 +555,30 @@ func (p *InternalTLSConfig) Init(base *BaseTable) { p.InternalTLSServerPemPath = ParamItem{ Key: "internaltls.serverPemPath", - Version: "2.0.0", + Version: "2.5.0", Export: true, } p.InternalTLSServerPemPath.Init(base.mgr) p.InternalTLSServerKeyPath = ParamItem{ Key: "internaltls.serverKeyPath", - Version: "2.0.0", + Version: "2.5.0", Export: true, } p.InternalTLSServerKeyPath.Init(base.mgr) p.InternalTLSCaPemPath = ParamItem{ Key: "internaltls.caPemPath", - Version: "2.0.0", + Version: "2.5.0", Export: true, } p.InternalTLSCaPemPath.Init(base.mgr) + + p.InternalTLSSNI = ParamItem{ + Key: "internaltls.sni", + Version: "2.5.0", + Export: true, + Doc: "The server name indication (SNI) for internal TLS, should be the same as the name provided by the certificates ref: https://en.wikipedia.org/wiki/Server_Name_Indication", + } + p.InternalTLSSNI.Init(base.mgr) } diff --git a/pkg/util/paramtable/grpc_param_test.go b/pkg/util/paramtable/grpc_param_test.go index bf07dfeaa98e7..df5014de16e47 100644 --- a/pkg/util/paramtable/grpc_param_test.go +++ b/pkg/util/paramtable/grpc_param_test.go @@ -189,8 +189,10 @@ func TestInternalTLSParams(t *testing.T) { base.Save("internaltls.serverPemPath", "/pem") base.Save("internaltls.serverKeyPath", "/key") base.Save("internaltls.caPemPath", "/ca") + base.Save("internaltls.sni", "localhost") assert.Equal(t, internalTLSCfg.InternalTLSEnabled.GetAsBool(), true) assert.Equal(t, internalTLSCfg.InternalTLSServerPemPath.GetValue(), "/pem") assert.Equal(t, internalTLSCfg.InternalTLSServerKeyPath.GetValue(), "/key") assert.Equal(t, internalTLSCfg.InternalTLSCaPemPath.GetValue(), "/ca") + assert.Equal(t, internalTLSCfg.InternalTLSSNI.GetValue(), "localhost") } diff --git a/tests/integration/internaltls/internaltls_test.go b/tests/integration/internaltls/internaltls_test.go index 9967051f02ad6..39d6dd5940520 100644 --- a/tests/integration/internaltls/internaltls_test.go +++ b/tests/integration/internaltls/internaltls_test.go @@ -50,27 +50,6 @@ type InternaltlsTestSuit struct { // Define the content for the configuration YAML file var configContent = ` -rootCoord: - ip: localhost - -proxy: - ip: localhost - -queryCoord: - ip: localhost - -queryNode: - ip: localhost - -indexNode: - ip: localhost - -dataCoord: - ip: localhost - -dataNode: - ip: localhost - common: security: internaltlsEnabled : true @@ -79,6 +58,7 @@ internaltls: serverPemPath: ../../../configs/cert/server.pem serverKeyPath: ../../../configs/cert/server.key caPemPath: ../../../configs/cert/ca.pem + sni: localhost ` const configFilePath = "../../../configs/_test.yaml"