Skip to content

Commit

Permalink
feat: add config field to set internal tls sni (#38124)
Browse files Browse the repository at this point in the history
/cc @xiaofan-luan @jaime0815 @nish112022

part of #36864

Signed-off-by: haorenfsa <[email protected]>
  • Loading branch information
haorenfsa authored Dec 4, 2024
1 parent cbba296 commit 1f66b9e
Show file tree
Hide file tree
Showing 15 changed files with 84 additions and 34 deletions.
4 changes: 2 additions & 2 deletions cmd/tools/config/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,11 @@ func WriteYaml(w io.Writer) {
},
{
name: "tls",
header: "\n# Configure the proxy tls enable.",
header: "\n# Configure external tls.",
},
{
name: "internaltls",
header: "\n# Configure the node-tls enable.",
header: "\n# Configure internal tls.",
},
{
name: "common",
Expand Down
5 changes: 3 additions & 2 deletions configs/milvus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -778,17 +778,18 @@ grpc:
maxCancelError: 32
minSessionCheckInterval: 200

# Configure the proxy tls enable.
# Configure external tls.
tls:
serverPemPath: configs/cert/server.pem
serverKeyPath: configs/cert/server.key
caPemPath: configs/cert/ca.pem

# Configure the node-tls enable.
# Configure internal tls.
internaltls:
serverPemPath: configs/cert/server.pem
serverKeyPath: configs/cert/server.key
caPemPath: configs/cert/ca.pem
sni: localhost # The server name indication (SNI) for internal TLS, should be the same as the name provided by the certificates ref: https://en.wikipedia.org/wiki/Server_Name_Indication

common:
defaultPartitionName: _default # Name of the default partition when a collection is created
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/datacoord/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ func NewClient(ctx context.Context) (types.DataCoordClient, error) {
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue())
}
return client, nil
}
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/datanode/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func NewClient(ctx context.Context, addr string, serverID int64) (types.DataNode
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue())
}
return client, nil
}
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/indexnode/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64, encryption bool)
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue())
}
return client, nil
}
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/proxy/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.ProxyClien
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue())
}
return client, nil
}
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/querycoord/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func NewClient(ctx context.Context) (types.QueryCoordClient, error) {
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue())
}
return client, nil
}
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/querynode/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.QueryNodeC
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue())
}
return client, nil
}
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/rootcoord/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ func NewClient(ctx context.Context) (types.RootCoordClient, error) {
return nil, err
}
client.grpcClient.SetInternalTLSCertPool(cp)
client.grpcClient.SetInternalTLSServerName(Params.InternalTLSCfg.InternalTLSSNI.GetValue())
}
return client, nil
}
Expand Down
34 changes: 34 additions & 0 deletions internal/mocks/mock_grpc_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 17 additions & 5 deletions internal/util/grpcclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ type GrpcClient[T GrpcComponent] interface {
SetGetAddrFunc(func() (string, error))
EnableEncryption()
SetInternalTLSCertPool(cp *x509.CertPool)
SetInternalTLSServerName(cp string)
SetNewGrpcClientFunc(func(cc *grpc.ClientConn) T)
ReCall(ctx context.Context, caller func(client T) (any, error)) (any, error)
Call(ctx context.Context, caller func(client T) (any, error)) (any, error)
Expand All @@ -103,10 +104,12 @@ type ClientBase[T interface {
newGrpcClient func(cc *grpc.ClientConn) T

// grpcClient T
grpcClient *clientConnWrapper[T]
encryption bool
cpInternalTLS *x509.CertPool
addr atomic.String
grpcClient *clientConnWrapper[T]
encryption bool
cpInternalTLS *x509.CertPool
addr atomic.String
internalTLSServerName string

// conn *grpc.ClientConn
grpcClientMtx sync.RWMutex
role string
Expand Down Expand Up @@ -194,6 +197,10 @@ func (c *ClientBase[T]) SetInternalTLSCertPool(cp *x509.CertPool) {
c.cpInternalTLS = cp
}

func (c *ClientBase[T]) SetInternalTLSServerName(cp string) {
c.internalTLSServerName = cp
}

// SetNewGrpcClientFunc sets newGrpcClient of client
func (c *ClientBase[T]) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) T) {
c.newGrpcClient = f
Expand Down Expand Up @@ -269,7 +276,12 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
dialContext,
addr,
// #nosec G402
grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{RootCAs: c.cpInternalTLS})),
grpc.WithTransportCredentials(credentials.NewTLS(
&tls.Config{
RootCAs: c.cpInternalTLS,
ServerName: c.internalTLSServerName,
},
)),
grpc.WithBlock(),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(c.ClientMaxRecvSize),
Expand Down
5 changes: 5 additions & 0 deletions internal/util/mock/grpcclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type GRPCClientBase[T any] struct {

grpcClient T
cpInternalTLS *x509.CertPool
cpInternalSNI string
conn *grpc.ClientConn
grpcClientMtx sync.RWMutex
GetGrpcClientErr error
Expand Down Expand Up @@ -66,6 +67,10 @@ func (c *GRPCClientBase[T]) SetInternalTLSCertPool(cp *x509.CertPool) {
c.cpInternalTLS = cp
}

func (c *GRPCClientBase[T]) SetInternalTLSServerName(cp string) {
c.cpInternalSNI = cp
}

func (c *GRPCClientBase[T]) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) T) {
c.newGrpcClient = f
}
Expand Down
17 changes: 13 additions & 4 deletions pkg/util/paramtable/grpc_param.go
Original file line number Diff line number Diff line change
Expand Up @@ -541,35 +541,44 @@ type InternalTLSConfig struct {
InternalTLSServerPemPath ParamItem `refreshable:"false"`
InternalTLSServerKeyPath ParamItem `refreshable:"false"`
InternalTLSCaPemPath ParamItem `refreshable:"false"`
InternalTLSSNI ParamItem `refreshable:"false"`
}

func (p *InternalTLSConfig) Init(base *BaseTable) {
p.InternalTLSEnabled = ParamItem{
Key: "common.security.internaltlsEnabled",
Version: "2.0.0",
Version: "2.5.0",
DefaultValue: "false",
Export: true,
}
p.InternalTLSEnabled.Init(base.mgr)

p.InternalTLSServerPemPath = ParamItem{
Key: "internaltls.serverPemPath",
Version: "2.0.0",
Version: "2.5.0",
Export: true,
}
p.InternalTLSServerPemPath.Init(base.mgr)

p.InternalTLSServerKeyPath = ParamItem{
Key: "internaltls.serverKeyPath",
Version: "2.0.0",
Version: "2.5.0",
Export: true,
}
p.InternalTLSServerKeyPath.Init(base.mgr)

p.InternalTLSCaPemPath = ParamItem{
Key: "internaltls.caPemPath",
Version: "2.0.0",
Version: "2.5.0",
Export: true,
}
p.InternalTLSCaPemPath.Init(base.mgr)

p.InternalTLSSNI = ParamItem{
Key: "internaltls.sni",
Version: "2.5.0",
Export: true,
Doc: "The server name indication (SNI) for internal TLS, should be the same as the name provided by the certificates ref: https://en.wikipedia.org/wiki/Server_Name_Indication",
}
p.InternalTLSSNI.Init(base.mgr)
}
2 changes: 2 additions & 0 deletions pkg/util/paramtable/grpc_param_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,10 @@ func TestInternalTLSParams(t *testing.T) {
base.Save("internaltls.serverPemPath", "/pem")
base.Save("internaltls.serverKeyPath", "/key")
base.Save("internaltls.caPemPath", "/ca")
base.Save("internaltls.sni", "localhost")
assert.Equal(t, internalTLSCfg.InternalTLSEnabled.GetAsBool(), true)
assert.Equal(t, internalTLSCfg.InternalTLSServerPemPath.GetValue(), "/pem")
assert.Equal(t, internalTLSCfg.InternalTLSServerKeyPath.GetValue(), "/key")
assert.Equal(t, internalTLSCfg.InternalTLSCaPemPath.GetValue(), "/ca")
assert.Equal(t, internalTLSCfg.InternalTLSSNI.GetValue(), "localhost")
}
22 changes: 1 addition & 21 deletions tests/integration/internaltls/internaltls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,6 @@ type InternaltlsTestSuit struct {

// Define the content for the configuration YAML file
var configContent = `
rootCoord:
ip: localhost
proxy:
ip: localhost
queryCoord:
ip: localhost
queryNode:
ip: localhost
indexNode:
ip: localhost
dataCoord:
ip: localhost
dataNode:
ip: localhost
common:
security:
internaltlsEnabled : true
Expand All @@ -79,6 +58,7 @@ internaltls:
serverPemPath: ../../../configs/cert/server.pem
serverKeyPath: ../../../configs/cert/server.key
caPemPath: ../../../configs/cert/ca.pem
sni: localhost
`

const configFilePath = "../../../configs/_test.yaml"
Expand Down

0 comments on commit 1f66b9e

Please sign in to comment.