diff --git a/rabbitmq/kubedb_client_builder.go b/rabbitmq/kubedb_client_builder.go index 87f6fef2..e3bbd1e4 100644 --- a/rabbitmq/kubedb_client_builder.go +++ b/rabbitmq/kubedb_client_builder.go @@ -18,9 +18,14 @@ package rabbitmq import ( "context" + "crypto/tls" + "crypto/x509" "errors" "fmt" + "net" + "net/http" "strings" + "time" rmqhttp "github.com/michaelklishin/rabbit-hole/v2" amqp "github.com/rabbitmq/amqp091-go" @@ -29,6 +34,7 @@ import ( "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" "kubedb.dev/apimachinery/apis/kubedb" + dbapi "kubedb.dev/apimachinery/apis/kubedb/v1" olddbapi "kubedb.dev/apimachinery/apis/kubedb/v1alpha2" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -129,6 +135,43 @@ func (o *KubeDBClientBuilder) GetRabbitMQClient() (*Client, error) { username, password = "guest", "guest" } + var tlsConfig *tls.Config + if o.db.Spec.EnableSSL { + certSecret := &core.Secret{} + err := o.kc.Get(o.ctx, types.NamespacedName{ + Namespace: o.db.Namespace, + Name: o.db.GetCertSecretName(olddbapi.RabbitmqClientCert), + }, certSecret) + if err != nil { + if kerr.IsNotFound(err) { + klog.Error(err, "Client certificate secret not found") + return nil, errors.New("client certificate secret is not found") + } + klog.Error(err, "Failed to get client certificate Secret") + return nil, err + } + + // get tls cert, clientCA and rootCA for tls config + clientCA := x509.NewCertPool() + rootCA := x509.NewCertPool() + + crt, err := tls.X509KeyPair(certSecret.Data[core.TLSCertKey], certSecret.Data[core.TLSPrivateKeyKey]) + if err != nil { + klog.Error(err, "Failed to parse private key pair") + return nil, err + } + clientCA.AppendCertsFromPEM(certSecret.Data[dbapi.TLSCACertFileName]) + rootCA.AppendCertsFromPEM(certSecret.Data[dbapi.TLSCACertFileName]) + + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{crt}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: clientCA, + RootCAs: rootCA, + MaxVersion: tls.VersionTLS13, + } + } + rmqClient := &Client{} defaultVhost := "/" @@ -136,7 +179,19 @@ func (o *KubeDBClientBuilder) GetRabbitMQClient() (*Client, error) { if o.httpURL == "" { o.httpURL = o.GetHTTPconnURL() } - httpClient, err := rmqhttp.NewClient(o.httpURL, username, password) + httpClient, err := func(isTLSEnabled bool) (*rmqhttp.Client, error) { + if isTLSEnabled { + return rmqhttp.NewTLSClient(o.httpURL, username, password, &http.Transport{ + IdleConnTimeout: time.Second * 3, + DialContext: (&net.Dialer{ + Timeout: time.Second * 30, + }).DialContext, + TLSClientConfig: tlsConfig, + TLSHandshakeTimeout: time.Second * 30, + }) + } + return rmqhttp.NewClient(o.httpURL, username, password) + }(o.db.Spec.EnableSSL) if err != nil { klog.Error(err, "Failed to get http client for rabbitmq") return nil, err @@ -184,11 +239,18 @@ func (o *KubeDBClientBuilder) GetAMQPconnURL(username string, password string, v } func (o *KubeDBClientBuilder) GetHTTPconnURL() string { - protocolScheme := rmqhttp.HTTP + protocolScheme := o.db.GetConnectionScheme() + connectionPort := func(scheme string) int { + if scheme == "http" { + return kubedb.RabbitMQManagementUIPort + } else { + return kubedb.RabbitMQManagementUIPortWithSSL + } + }(protocolScheme) if o.podName != "" { - return fmt.Sprintf("%s://%s.%s.%s.svc:%d", protocolScheme, o.podName, o.db.GoverningServiceName(), o.db.Namespace, kubedb.RabbitMQManagementUIPort) + return fmt.Sprintf("%s://%s.%s.%s.svc:%d", protocolScheme, o.podName, o.db.GoverningServiceName(), o.db.Namespace, connectionPort) } - return fmt.Sprintf("%s://%s.%s.svc.cluster.local:%d", protocolScheme, o.db.DashboardServiceName(), o.db.Namespace, kubedb.RabbitMQManagementUIPort) + return fmt.Sprintf("%s://%s.%s.svc.cluster.local:%d", protocolScheme, o.db.DashboardServiceName(), o.db.Namespace, connectionPort) } // RabbitMQ server have a default virtual host "/"