From 92a4f8bd329cd7f82aab2e3bc723a5c0158dca6a Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Wed, 4 Dec 2024 16:03:41 +0300 Subject: [PATCH 1/4] NOISSUE - Fix loading of CA certs on agent (#321) * debug connection Signed-off-by: Sammy Oina * actual fix Signed-off-by: Sammy Oina * remove debugs Signed-off-by: Sammy Oina * remove test Signed-off-by: Sammy Oina * add unit test Signed-off-by: Sammy Oina * more tests Signed-off-by: Sammy Oina * consolidate tests Signed-off-by: Sammy Oina * fix client auth Signed-off-by: Sammy Oina * debug Signed-off-by: Sammy Oina * better handling Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- cli/sdk.go | 4 +- cmd/cli/main.go | 2 +- internal/server/grpc/grpc.go | 36 +++-- internal/server/grpc/grpc_test.go | 259 +++++++++++++++++++++++++----- pkg/clients/grpc/agent/agent.go | 2 +- pkg/clients/grpc/connect.go | 3 +- 6 files changed, 246 insertions(+), 60 deletions(-) diff --git a/cli/sdk.go b/cli/sdk.go index f4099115..aa545a2c 100644 --- a/cli/sdk.go +++ b/cli/sdk.go @@ -5,6 +5,7 @@ package cli import ( "context" + "github.com/spf13/cobra" "github.com/ultravioletrs/cocos/pkg/clients/grpc" "github.com/ultravioletrs/cocos/pkg/clients/grpc/agent" "github.com/ultravioletrs/cocos/pkg/sdk" @@ -25,12 +26,13 @@ func New(config grpc.Config) *CLI { } } -func (c *CLI) InitializeSDK() error { +func (c *CLI) InitializeSDK(cmd *cobra.Command) error { agentGRPCClient, agentClient, err := agent.NewAgentClient(context.Background(), c.config) if err != nil { c.connectErr = err return err } + cmd.Println("🔗 Connected to agent using ", agentGRPCClient.Secure()) c.client = agentGRPCClient c.agentSDK = sdk.NewAgentSDK(agentClient) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index c67ca252..63e76575 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -100,7 +100,7 @@ func main() { cliSVC := cli.New(agentGRPCConfig) - if err := cliSVC.InitializeSDK(); err == nil { + if err := cliSVC.InitializeSDK(rootCmd); err == nil { defer cliSVC.Close() } diff --git a/internal/server/grpc/grpc.go b/internal/server/grpc/grpc.go index f5f5a854..c3ce35a4 100644 --- a/internal/server/grpc/grpc.go +++ b/internal/server/grpc/grpc.go @@ -127,7 +127,7 @@ func (s *Server) Start() error { return fmt.Errorf("failed to load auth certificates: %w", err) } tlsConfig := &tls.Config{ - ClientAuth: tls.RequireAndVerifyClientCert, + ClientAuth: tls.NoClientCert, Certificates: []tls.Certificate{certificate}, } @@ -161,12 +161,17 @@ func (s *Server) Start() error { } mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, s.Config.ClientCAFile) } + + if mtlsCA != "" { + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + creds = grpc.Creds(credentials.NewTLS(tlsConfig)) switch { case mtlsCA != "": - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile, mtlsCA)) + s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS", s.Name, s.Address)) default: - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile)) + s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS", s.Name, s.Address)) } listener, err = net.Listen("tcp", s.Address) @@ -223,31 +228,28 @@ func (s *Server) Stop() error { func loadCertFile(certFile string) ([]byte, error) { if certFile != "" { - return os.ReadFile(certFile) + return readFileOrData(certFile) } return []byte{}, nil } -func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) { - var cert, key []byte - var err error - - readFileOrData := func(input string) ([]byte, error) { - if len(input) < 1000 && !strings.Contains(input, "\n") { - data, err := os.ReadFile(input) - if err == nil { - return data, nil - } +func readFileOrData(input string) ([]byte, error) { + if len(input) < 1000 && !strings.Contains(input, "\n") { + data, err := os.ReadFile(input) + if err == nil { + return data, nil } - return []byte(input), nil } + return []byte(input), nil +} - cert, err = readFileOrData(certfile) +func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) { + cert, err := readFileOrData(certfile) if err != nil { return tls.Certificate{}, fmt.Errorf("failed to read cert: %v", err) } - key, err = readFileOrData(keyfile) + key, err := readFileOrData(keyfile) if err != nil { return tls.Certificate{}, fmt.Errorf("failed to read key: %v", err) } diff --git a/internal/server/grpc/grpc_test.go b/internal/server/grpc/grpc_test.go index b6b1c39b..248980e4 100644 --- a/internal/server/grpc/grpc_test.go +++ b/internal/server/grpc/grpc_test.go @@ -12,6 +12,7 @@ import ( "fmt" "log/slog" "math/big" + "os" "strings" "sync" "testing" @@ -51,49 +52,39 @@ func TestNew(t *testing.T) { assert.IsType(t, &Server{}, srv) } -func TestServerStart(t *testing.T) { +func TestServerStartWithTLSFile(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - config := server.Config{ - Host: "localhost", - Port: "0", - } - buf := &ThreadSafeBuffer{} - logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug})) - qp := new(mocks.QuoteProvider) - authSvc := new(authmocks.Authenticator) - - srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc) - - var wg sync.WaitGroup - wg.Add(1) - - go func() { - wg.Done() - err := srv.Start() - assert.NoError(t, err) - }() + cert, key, err := generateSelfSignedCert() + assert.NoError(t, err) - wg.Wait() + certFile, err := os.CreateTemp("", "cert*.pem") + assert.NoError(t, err) - time.Sleep(100 * time.Millisecond) + keyFile, err := os.CreateTemp("", "key*.pem") + assert.NoError(t, err) - cancel() + t.Cleanup(func() { + os.Remove(certFile.Name()) + os.Remove(keyFile.Name()) + }) - assert.Contains(t, buf.String(), "TestServer service gRPC server listening at localhost:0 without TLS") -} + _, err = certFile.Write(cert) + assert.NoError(t, err) -func TestServerStartWithTLS(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + _, err = keyFile.Write(key) + assert.NoError(t, err) - cert, key, err := generateSelfSignedCert() + err = certFile.Close() + assert.NoError(t, err) + err = keyFile.Close() assert.NoError(t, err) config := server.Config{ Host: "localhost", Port: "0", - CertFile: string(cert), - KeyFile: string(key), + CertFile: certFile.Name(), + KeyFile: keyFile.Name(), } logBuffer := &ThreadSafeBuffer{} @@ -125,13 +116,41 @@ func TestServerStartWithTLS(t *testing.T) { assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS") } -func TestServerStartWithAttestedTLS(t *testing.T) { +func TestServerStartWithmTLSFile(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) + cert, key, err := generateSelfSignedCert() + assert.NoError(t, err) + + certFile, err := os.CreateTemp("", "cert*.pem") + assert.NoError(t, err) + + keyFile, err := os.CreateTemp("", "key*.pem") + assert.NoError(t, err) + + t.Cleanup(func() { + os.Remove(certFile.Name()) + os.Remove(keyFile.Name()) + }) + + _, err = certFile.Write(cert) + assert.NoError(t, err) + + _, err = keyFile.Write(key) + assert.NoError(t, err) + + err = certFile.Close() + assert.NoError(t, err) + err = keyFile.Close() + assert.NoError(t, err) + config := server.Config{ - Host: "localhost", - Port: "0", - AttestedTLS: true, + Host: "localhost", + Port: "0", + CertFile: certFile.Name(), + KeyFile: keyFile.Name(), + ServerCAFile: certFile.Name(), + ClientCAFile: certFile.Name(), } logBuffer := &ThreadSafeBuffer{} @@ -152,16 +171,15 @@ func TestServerStartWithAttestedTLS(t *testing.T) { wg.Wait() - time.Sleep(100 * time.Millisecond) + time.Sleep(200 * time.Millisecond) cancel() - time.Sleep(1000 * time.Millisecond) + time.Sleep(200 * time.Millisecond) logContent := logBuffer.String() - assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with Attested TLS") - - qp.AssertExpectations(t) + fmt.Println(logContent) + assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS") } func TestServerStop(t *testing.T) { @@ -246,3 +264,166 @@ func (b *ThreadSafeBuffer) String() string { defer b.mu.Unlock() return b.buffer.String() } + +func TestServerInitializationAndStartup(t *testing.T) { + testCases := []struct { + name string + config server.Config + expectedLog string + expectError bool + setupCallback func(*testing.T, *server.Config, *ThreadSafeBuffer) + }{ + { + name: "Non-TLS Server Startup", + config: server.Config{ + Host: "localhost", + Port: "0", + }, + expectedLog: "TestServer service gRPC server listening at localhost:0 without TLS", + }, + { + name: "TLS Server Startup with Self-Signed Certificate", + config: server.Config{ + Host: "localhost", + Port: "0", + }, + setupCallback: setupTLSConfig, + expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS", + }, + { + name: "TLS Server Startup with Invalid Certificates", + config: server.Config{ + Host: "localhost", + Port: "0", + CertFile: "invalid", + KeyFile: "invalid", + }, + expectError: true, + expectedLog: "failed to load auth certificates", + }, + { + name: "mTLS Server Startup", + config: server.Config{ + Host: "localhost", + Port: "0", + }, + setupCallback: setupMTLSConfig, + expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS", + }, + { + name: "mTLS Server Startup with Invalid Root CA", + config: server.Config{ + Host: "localhost", + Port: "0", + ServerCAFile: "invalid", + }, + setupCallback: setupInvalidRootCAConfig, + expectError: true, + expectedLog: "failed to append root ca to tls.Config", + }, + { + name: "mTLS Server Startup with Invalid Client CA", + config: server.Config{ + Host: "localhost", + Port: "0", + ServerCAFile: "invalid", + }, + setupCallback: setupInvalidClientCAConfig, + expectError: true, + expectedLog: "failed to append client ca to tls.Config", + }, + { + name: "Attested TLS Server Startup", + config: server.Config{ + Host: "localhost", + Port: "0", + AttestedTLS: true, + }, + expectedLog: "TestServer service gRPC server listening at localhost:0 with Attested TLS", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if tc.setupCallback != nil { + tc.setupCallback(t, &tc.config, nil) + } + + logBuffer := &ThreadSafeBuffer{} + logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) + qp := new(mocks.QuoteProvider) + authSvc := new(authmocks.Authenticator) + + srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, qp, authSvc) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + wg.Done() + err := srv.Start() + if tc.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedLog) + } else { + assert.NoError(t, err) + } + }() + + wg.Wait() + + time.Sleep(200 * time.Millisecond) + + cancel() + + time.Sleep(200 * time.Millisecond) + + if !tc.expectError { + logContent := logBuffer.String() + fmt.Println(logContent) + assert.Contains(t, logContent, tc.expectedLog) + } + }) + } +} + +func setupTLSConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { + cert, key, err := generateSelfSignedCert() + assert.NoError(t, err) + + config.CertFile = string(cert) + config.KeyFile = string(key) +} + +func setupMTLSConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { + cert, key, err := generateSelfSignedCert() + assert.NoError(t, err) + + config.CertFile = string(cert) + config.KeyFile = string(key) + config.ServerCAFile = string(cert) + config.ClientCAFile = string(cert) +} + +func setupInvalidRootCAConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { + cert, key, err := generateSelfSignedCert() + assert.NoError(t, err) + + config.CertFile = string(cert) + config.KeyFile = string(key) + config.ServerCAFile = "invalid" + config.ClientCAFile = string(cert) +} + +func setupInvalidClientCAConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { + cert, key, err := generateSelfSignedCert() + assert.NoError(t, err) + + config.CertFile = string(cert) + config.KeyFile = string(key) + config.ClientCAFile = "invalid" + config.ServerCAFile = string(cert) +} diff --git a/pkg/clients/grpc/agent/agent.go b/pkg/clients/grpc/agent/agent.go index de063bfb..9e514927 100644 --- a/pkg/clients/grpc/agent/agent.go +++ b/pkg/clients/grpc/agent/agent.go @@ -20,7 +20,7 @@ func NewAgentClient(ctx context.Context, cfg grpc.Config) (grpc.Client, agent.Ag return nil, nil, err } - if client.Secure() != grpc.WithATLS { + if client.Secure() != grpc.WithATLS && client.Secure() != grpc.WithTLS { health := grpchealth.NewHealthClient(client.Connection()) resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{ Service: "agent", diff --git a/pkg/clients/grpc/connect.go b/pkg/clients/grpc/connect.go index 7dd4cd8b..c091bfbf 100644 --- a/pkg/clients/grpc/connect.go +++ b/pkg/clients/grpc/connect.go @@ -35,6 +35,7 @@ const ( const ( AttestationReportSize = 0x4A0 WithATLS = "with aTLS" + WithTLS = "with TLS" ) var ( @@ -102,7 +103,7 @@ func (c *client) Close() error { func (c *client) Secure() string { switch c.secure { case withTLS: - return "with TLS" + return WithTLS case withmTLS: return "with mTLS" case withaTLS: From 28c751113dda3eec1d058928c2cee52e79a1837c Mon Sep 17 00:00:00 2001 From: WashingtonKK Date: Fri, 1 Nov 2024 14:09:48 +0300 Subject: [PATCH 2/4] restructure grpc configs Signed-off-by: WashingtonKK enhance clients Signed-off-by: WashingtonKK restructure config Signed-off-by: WashingtonKK refactor Signed-off-by: WashingtonKK rebase Signed-off-by: WashingtonKK rebase Signed-off-by: WashingtonKK use separate configuration Signed-off-by: WashingtonKK fix tests Signed-off-by: WashingtonKK fix config Signed-off-by: WashingtonKK refactor Signed-off-by: WashingtonKK Lint Signed-off-by: WashingtonKK fix tests Signed-off-by: WashingtonKK add tests Signed-off-by: WashingtonKK add test case Signed-off-by: WashingtonKK add test case Signed-off-by: WashingtonKK refactor Signed-off-by: WashingtonKK further refactor' Signed-off-by: WashingtonKK add tests Signed-off-by: WashingtonKK rebase Signed-off-by: WashingtonKK --- cli/sdk.go | 4 +- cmd/agent/main.go | 22 +- cmd/cli/main.go | 2 +- cmd/manager/main.go | 4 +- go.mod | 4 +- internal/server/grpc/grpc.go | 162 +++++++------ internal/server/grpc/grpc_test.go | 279 ++++++++++++++++++----- internal/server/server.go | 31 ++- pkg/clients/grpc/agent/agent.go | 2 +- pkg/clients/grpc/agent/agent_test.go | 26 ++- pkg/clients/grpc/connect.go | 124 ++++++---- pkg/clients/grpc/connect_test.go | 72 +++++- pkg/clients/grpc/manager/manager.go | 2 +- pkg/clients/grpc/manager/manager_test.go | 13 +- test/computations/main.go | 6 +- 15 files changed, 521 insertions(+), 232 deletions(-) diff --git a/cli/sdk.go b/cli/sdk.go index aa545a2c..672cf500 100644 --- a/cli/sdk.go +++ b/cli/sdk.go @@ -15,12 +15,12 @@ var Verbose bool type CLI struct { agentSDK sdk.SDK - config grpc.Config + config grpc.AgentClientConfig client grpc.Client connectErr error } -func New(config grpc.Config) *CLI { +func New(config grpc.AgentClientConfig) *CLI { return &CLI{ config: config, } diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 1c7f69d2..ab088154 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -97,14 +97,18 @@ func main() { svc := newService(ctx, logger, eventSvc, cfg, qp) - grpcServerConfig := server.Config{ - Port: cfg.AgentConfig.Port, - Host: cfg.AgentConfig.Host, - CertFile: cfg.AgentConfig.CertFile, - KeyFile: cfg.AgentConfig.KeyFile, - ServerCAFile: cfg.AgentConfig.ServerCAFile, - ClientCAFile: cfg.AgentConfig.ClientCAFile, - AttestedTLS: cfg.AgentConfig.AttestedTls, + agentGrpcServerConfig := server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: cfg.AgentConfig.Host, + Port: cfg.AgentConfig.Port, + CertFile: cfg.AgentConfig.CertFile, + KeyFile: cfg.AgentConfig.KeyFile, + ServerCAFile: cfg.AgentConfig.ServerCAFile, + ClientCAFile: cfg.AgentConfig.ClientCAFile, + }, + }, + AttestedTLS: cfg.AgentConfig.AttestedTls, } registerAgentServiceServer := func(srv *grpc.Server) { @@ -119,7 +123,7 @@ func main() { return } - gs := grpcserver.New(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger, qp, authSvc) + gs := grpcserver.New(ctx, cancel, svcName, agentGrpcServerConfig, registerAgentServiceServer, logger, qp, authSvc) g.Go(func() error { for { diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 63e76575..916e3703 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -91,7 +91,7 @@ func main() { return } - agentGRPCConfig := grpc.Config{} + agentGRPCConfig := grpc.AgentClientConfig{} if err := env.ParseWithOptions(&agentGRPCConfig, env.Options{Prefix: envPrefixAgentGRPC}); err != nil { message := color.New(color.FgRed).Sprintf("failed to load %s gRPC client configuration : %s", svcName, err) rootCmd.Println(message) diff --git a/cmd/manager/main.go b/cmd/manager/main.go index 226607f9..bb9933d4 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -25,7 +25,7 @@ import ( "github.com/ultravioletrs/cocos/manager/events" "github.com/ultravioletrs/cocos/manager/qemu" "github.com/ultravioletrs/cocos/manager/tracing" - "github.com/ultravioletrs/cocos/pkg/clients/grpc" + pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc" managergrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/manager" "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" @@ -92,7 +92,7 @@ func main() { args := qemuCfg.ConstructQemuArgs() logger.Info(strings.Join(args, " ")) - managerGRPCConfig := grpc.Config{} + managerGRPCConfig := pkggrpc.ManagerClientConfig{} if err := env.ParseWithOptions(&managerGRPCConfig, env.Options{Prefix: envPrefixGRPC}); err != nil { logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)) exitCode = 1 diff --git a/go.mod b/go.mod index 8fbd0b15..48c3b8c9 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/ultravioletrs/cocos -go 1.22.7 - -toolchain go1.23.1 +go 1.23.0 require ( github.com/absmach/magistrala v0.14.1-0.20240709113739-04c359462746 diff --git a/internal/server/grpc/grpc.go b/internal/server/grpc/grpc.go index c3ce35a4..33ae890c 100644 --- a/internal/server/grpc/grpc.go +++ b/internal/server/grpc/grpc.go @@ -60,8 +60,9 @@ type serviceRegister func(srv *grpc.Server) var _ server.Server = (*Server)(nil) -func New(ctx context.Context, cancel context.CancelFunc, name string, config server.Config, registerService serviceRegister, logger *slog.Logger, qp client.QuoteProvider, authSvc auth.Authenticator) server.Server { - listenFullAddress := fmt.Sprintf("%s:%s", config.Host, config.Port) +func New(ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration, registerService serviceRegister, logger *slog.Logger, qp client.QuoteProvider, authSvc auth.Authenticator) server.Server { + base := config.GetBaseConfig() + listenFullAddress := fmt.Sprintf("%s:%s", base.Host, base.Port) return &Server{ BaseServer: server.BaseServer{ Ctx: ctx, @@ -91,101 +92,98 @@ func (s *Server) Start() error { creds := grpc.Creds(insecure.NewCredentials()) var listener net.Listener = nil + switch c := s.Config.(type) { + case server.AgentConfig: + switch { + case c.AttestedTLS: + certificateBytes, privateKeyBytes, err := generateCertificatesForATLS() + if err != nil { + return fmt.Errorf("failed to create certificate: %w", err) + } - switch { - case s.Config.AttestedTLS: - certificateBytes, privateKeyBytes, err := generateCertificatesForATLS() - if err != nil { - return fmt.Errorf("failed to create certificate: %w", err) - } - - certificate, err := tls.X509KeyPair(certificateBytes, privateKeyBytes) - if err != nil { - return fmt.Errorf("falied due to invalid key pair: %w", err) - } - - tlsConfig := &tls.Config{ - ClientAuth: tls.NoClientCert, - Certificates: []tls.Certificate{certificate}, - } + certificate, err := tls.X509KeyPair(certificateBytes, privateKeyBytes) + if err != nil { + return fmt.Errorf("falied due to invalid key pair: %w", err) + } - creds = grpc.Creds(credentials.NewTLS(tlsConfig)) + tlsConfig := &tls.Config{ + ClientAuth: tls.NoClientCert, + Certificates: []tls.Certificate{certificate}, + } - listener, err = atls.Listen( - s.Address, - certificateBytes, - privateKeyBytes, - ) - if err != nil { - return fmt.Errorf("failed to create Listener for aTLS: %w", err) - } - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address)) + creds = grpc.Creds(credentials.NewTLS(tlsConfig)) - case s.Config.CertFile != "" || s.Config.KeyFile != "": - certificate, err := loadX509KeyPair(s.Config.CertFile, s.Config.KeyFile) - if err != nil { - return fmt.Errorf("failed to load auth certificates: %w", err) - } - tlsConfig := &tls.Config{ - ClientAuth: tls.NoClientCert, - Certificates: []tls.Certificate{certificate}, - } + listener, err = atls.Listen( + s.Address, + certificateBytes, + privateKeyBytes, + ) + if err != nil { + return fmt.Errorf("failed to create Listener for aTLS: %w", err) + } + s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address)) - var mtlsCA string - // Loading Server CA file - rootCA, err := loadCertFile(s.Config.ServerCAFile) - if err != nil { - return fmt.Errorf("failed to load root ca file: %w", err) - } - if len(rootCA) > 0 { - if tlsConfig.RootCAs == nil { - tlsConfig.RootCAs = x509.NewCertPool() + case c.CertFile != "" || c.KeyFile != "": + certificate, err := loadX509KeyPair(c.CertFile, c.KeyFile) + if err != nil { + return fmt.Errorf("failed to load auth certificates: %w", err) } - if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) { - return fmt.Errorf("failed to append root ca to tls.Config") + tlsConfig := &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + Certificates: []tls.Certificate{certificate}, } - mtlsCA = fmt.Sprintf("root ca %s", s.Config.ServerCAFile) - } - // Loading Client CA File - clientCA, err := loadCertFile(s.Config.ClientCAFile) - if err != nil { - return fmt.Errorf("failed to load client ca file: %w", err) - } - if len(clientCA) > 0 { - if tlsConfig.ClientCAs == nil { - tlsConfig.ClientCAs = x509.NewCertPool() + var mtlsCA string + // Loading Server CA file + rootCA, err := loadCertFile(c.ServerCAFile) + if err != nil { + return fmt.Errorf("failed to load root ca file: %w", err) } - if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) { - return fmt.Errorf("failed to append client ca to tls.Config") + if len(rootCA) > 0 { + if tlsConfig.RootCAs == nil { + tlsConfig.RootCAs = x509.NewCertPool() + } + if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) { + return fmt.Errorf("failed to append root ca to tls.Config") + } + mtlsCA = fmt.Sprintf("root ca %s", c.ServerCAFile) } - mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, s.Config.ClientCAFile) - } - if mtlsCA != "" { - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - } + // Loading Client CA File + clientCA, err := loadCertFile(c.ClientCAFile) + if err != nil { + return fmt.Errorf("failed to load client ca file: %w", err) + } + if len(clientCA) > 0 { + if tlsConfig.ClientCAs == nil { + tlsConfig.ClientCAs = x509.NewCertPool() + } + if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) { + return fmt.Errorf("failed to append client ca to tls.Config") + } + mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, c.ClientCAFile) + } + creds = grpc.Creds(credentials.NewTLS(tlsConfig)) + switch { + case mtlsCA != "": + s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, c.CertFile, c.KeyFile, mtlsCA)) + default: + s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, c.CertFile, c.KeyFile)) + } - creds = grpc.Creds(credentials.NewTLS(tlsConfig)) - switch { - case mtlsCA != "": - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS", s.Name, s.Address)) + listener, err = net.Listen("tcp", s.Address) + if err != nil { + return fmt.Errorf("failed to listen on port %s: %w", s.Address, err) + } default: - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS", s.Name, s.Address)) - } + var err error - listener, err = net.Listen("tcp", s.Address) - if err != nil { - return fmt.Errorf("failed to listen on port %s: %w", s.Address, err) - } - default: - var err error - - listener, err = net.Listen("tcp", s.Address) - if err != nil { - return fmt.Errorf("failed to listen on port %s: %w", s.Address, err) + listener, err = net.Listen("tcp", s.Address) + if err != nil { + return fmt.Errorf("failed to listen on port %s: %w", s.Address, err) + } + s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address)) } - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address)) } grpcServerOptions = append(grpcServerOptions, creds) diff --git a/internal/server/grpc/grpc_test.go b/internal/server/grpc/grpc_test.go index 248980e4..d9a8f9e1 100644 --- a/internal/server/grpc/grpc_test.go +++ b/internal/server/grpc/grpc_test.go @@ -38,9 +38,13 @@ func TestNew(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - config := server.Config{ - Host: "localhost", - Port: "50051", + config := server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "50051", + }, + }, } logger := slog.Default() qp := new(mocks.QuoteProvider) @@ -52,7 +56,7 @@ func TestNew(t *testing.T) { assert.IsType(t, &Server{}, srv) } -func TestServerStartWithTLSFile(t *testing.T) { +func TestServerStartWithTLS(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cert, key, err := generateSelfSignedCert() @@ -80,11 +84,62 @@ func TestServerStartWithTLSFile(t *testing.T) { err = keyFile.Close() assert.NoError(t, err) - config := server.Config{ - Host: "localhost", - Port: "0", - CertFile: certFile.Name(), - KeyFile: keyFile.Name(), + config := server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + CertFile: certFile.Name(), + KeyFile: keyFile.Name(), + }, + }, + } + + logBuffer := &ThreadSafeBuffer{} + logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) + qp := new(mocks.QuoteProvider) + authSvc := new(authmocks.Authenticator) + + srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + wg.Done() + err := srv.Start() + assert.NoError(t, err) + }() + + wg.Wait() + + time.Sleep(200 * time.Millisecond) + + cancel() + + time.Sleep(200 * time.Millisecond) + + logContent := logBuffer.String() + fmt.Println(logContent) + assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS") +} + +func TestServerStartWithMTLS(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + caCertFile, clientCertFile, clientKeyFile, err := createCertificatesFiles() + assert.NoError(t, err) + + config := server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + CertFile: string(clientCertFile), + KeyFile: string(clientKeyFile), + ServerCAFile: caCertFile, + }, + }, } logBuffer := &ThreadSafeBuffer{} @@ -144,13 +199,17 @@ func TestServerStartWithmTLSFile(t *testing.T) { err = keyFile.Close() assert.NoError(t, err) - config := server.Config{ - Host: "localhost", - Port: "0", - CertFile: certFile.Name(), - KeyFile: keyFile.Name(), - ServerCAFile: certFile.Name(), - ClientCAFile: certFile.Name(), + config := server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + CertFile: certFile.Name(), + KeyFile: keyFile.Name(), + ServerCAFile: certFile.Name(), + ClientCAFile: certFile.Name(), + }, + }, } logBuffer := &ThreadSafeBuffer{} @@ -185,9 +244,13 @@ func TestServerStartWithmTLSFile(t *testing.T) { func TestServerStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - config := server.Config{ - Host: "localhost", - Port: "0", + config := server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + }, + }, } buf := &ThreadSafeBuffer{} logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug})) @@ -268,54 +331,74 @@ func (b *ThreadSafeBuffer) String() string { func TestServerInitializationAndStartup(t *testing.T) { testCases := []struct { name string - config server.Config + config server.AgentConfig expectedLog string expectError bool - setupCallback func(*testing.T, *server.Config, *ThreadSafeBuffer) + setupCallback func(*testing.T, *server.AgentConfig, *ThreadSafeBuffer) }{ { name: "Non-TLS Server Startup", - config: server.Config{ - Host: "localhost", - Port: "0", + config: server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + }, + }, }, expectedLog: "TestServer service gRPC server listening at localhost:0 without TLS", }, { name: "TLS Server Startup with Self-Signed Certificate", - config: server.Config{ - Host: "localhost", - Port: "0", + config: server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + }, + }, }, setupCallback: setupTLSConfig, expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS", }, { name: "TLS Server Startup with Invalid Certificates", - config: server.Config{ - Host: "localhost", - Port: "0", - CertFile: "invalid", - KeyFile: "invalid", + config: server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + CertFile: "invalid", + KeyFile: "invalid", + }, + }, }, expectError: true, expectedLog: "failed to load auth certificates", }, { name: "mTLS Server Startup", - config: server.Config{ - Host: "localhost", - Port: "0", + config: server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + }, + }, }, setupCallback: setupMTLSConfig, expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS", }, { name: "mTLS Server Startup with Invalid Root CA", - config: server.Config{ - Host: "localhost", - Port: "0", - ServerCAFile: "invalid", + config: server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + ServerCAFile: "invalid", + }, + }, }, setupCallback: setupInvalidRootCAConfig, expectError: true, @@ -323,10 +406,14 @@ func TestServerInitializationAndStartup(t *testing.T) { }, { name: "mTLS Server Startup with Invalid Client CA", - config: server.Config{ - Host: "localhost", - Port: "0", - ServerCAFile: "invalid", + config: server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + ServerCAFile: "invalid", + }, + }, }, setupCallback: setupInvalidClientCAConfig, expectError: true, @@ -334,9 +421,13 @@ func TestServerInitializationAndStartup(t *testing.T) { }, { name: "Attested TLS Server Startup", - config: server.Config{ - Host: "localhost", - Port: "0", + config: server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + }, + }, AttestedTLS: true, }, expectedLog: "TestServer service gRPC server listening at localhost:0 with Attested TLS", @@ -347,7 +438,6 @@ func TestServerInitializationAndStartup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if tc.setupCallback != nil { tc.setupCallback(t, &tc.config, nil) } @@ -358,7 +448,6 @@ func TestServerInitializationAndStartup(t *testing.T) { authSvc := new(authmocks.Authenticator) srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, qp, authSvc) - var wg sync.WaitGroup wg.Add(1) @@ -390,7 +479,7 @@ func TestServerInitializationAndStartup(t *testing.T) { } } -func setupTLSConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { +func setupTLSConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) { cert, key, err := generateSelfSignedCert() assert.NoError(t, err) @@ -398,7 +487,7 @@ func setupTLSConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { config.KeyFile = string(key) } -func setupMTLSConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { +func setupMTLSConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) { cert, key, err := generateSelfSignedCert() assert.NoError(t, err) @@ -408,7 +497,7 @@ func setupMTLSConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { config.ClientCAFile = string(cert) } -func setupInvalidRootCAConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { +func setupInvalidRootCAConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) { cert, key, err := generateSelfSignedCert() assert.NoError(t, err) @@ -418,7 +507,7 @@ func setupInvalidRootCAConfig(t *testing.T, config *server.Config, _ *ThreadSafe config.ClientCAFile = string(cert) } -func setupInvalidClientCAConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) { +func setupInvalidClientCAConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) { cert, key, err := generateSelfSignedCert() assert.NoError(t, err) @@ -427,3 +516,89 @@ func setupInvalidClientCAConfig(t *testing.T, config *server.Config, _ *ThreadSa config.ClientCAFile = "invalid" config.ServerCAFile = string(cert) } + +func createCertificatesFiles() (string, string, string, error) { + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", "", "", err + } + + caTemplate := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey) + if err != nil { + return "", "", "", err + } + + caCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCertDER})) + if err != nil { + return "", "", "", err + } + + clientKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", "", "", err + } + + clientTemplate := x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + } + + clientCertDER, err := x509.CreateCertificate(rand.Reader, &clientTemplate, &caTemplate, &clientKey.PublicKey, caKey) + if err != nil { + return "", "", "", err + } + + clientCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER})) + if err != nil { + return "", "", "", err + } + + clientKeyFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey)})) + if err != nil { + return "", "", "", err + } + + return caCertFile, clientCertFile, clientKeyFile, nil +} + +func createTempFile(data []byte) (string, error) { + file, err := createTempFileHandle() + if err != nil { + return "", err + } + + _, err = file.Write(data) + if err != nil { + return "", err + } + + err = file.Close() + if err != nil { + return "", err + } + + return file.Name(), nil +} + +func createTempFileHandle() (*os.File, error) { + return os.CreateTemp("", "test") +} diff --git a/internal/server/server.go b/internal/server/server.go index c50e9a11..188c2495 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -16,14 +16,25 @@ type Server interface { Stop() error } -type Config struct { - Host string `env:"HOST" envDefault:""` - Port string `env:"PORT" envDefault:""` +type ServerConfiguration interface { + GetBaseConfig() ServerConfig +} + +type BaseConfig struct { + Host string `env:"HOST" envDefault:"localhost"` + Port string `env:"PORT" envDefault:"7001"` + ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""` CertFile string `env:"SERVER_CERT" envDefault:""` KeyFile string `env:"SERVER_KEY" envDefault:""` - ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""` ClientCAFile string `env:"CLIENT_CA_CERTS" envDefault:""` - AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"` +} + +type ServerConfig struct { + BaseConfig +} +type AgentConfig struct { + ServerConfig + AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"` } type BaseServer struct { @@ -31,11 +42,19 @@ type BaseServer struct { Cancel context.CancelFunc Name string Address string - Config Config + Config ServerConfiguration Logger *slog.Logger Protocol string } +func (s ServerConfig) GetBaseConfig() ServerConfig { + return s +} + +func (a AgentConfig) GetBaseConfig() ServerConfig { + return a.ServerConfig +} + func stopAllServer(servers ...Server) error { var errs []error for _, server := range servers { diff --git a/pkg/clients/grpc/agent/agent.go b/pkg/clients/grpc/agent/agent.go index 9e514927..4880a2c2 100644 --- a/pkg/clients/grpc/agent/agent.go +++ b/pkg/clients/grpc/agent/agent.go @@ -14,7 +14,7 @@ import ( var ErrAgentServiceUnavailable = errors.New("agent service is unavailable") // NewAgentClient creates new agent gRPC client instance. -func NewAgentClient(ctx context.Context, cfg grpc.Config) (grpc.Client, agent.AgentServiceClient, error) { +func NewAgentClient(ctx context.Context, cfg grpc.AgentClientConfig) (grpc.Client, agent.AgentServiceClient, error) { client, err := grpc.NewClient(cfg) if err != nil { return nil, nil, err diff --git a/pkg/clients/grpc/agent/agent_test.go b/pkg/clients/grpc/agent/agent_test.go index 539e474d..ea6b8402 100644 --- a/pkg/clients/grpc/agent/agent_test.go +++ b/pkg/clients/grpc/agent/agent_test.go @@ -78,32 +78,38 @@ func TestAgentClientIntegration(t *testing.T) { tests := []struct { name string serverRunning bool - config pkggrpc.Config + config pkggrpc.AgentClientConfig err error }{ { name: "successful connection", serverRunning: true, - config: pkggrpc.Config{ - URL: testServer.listenAddr, - Timeout: 1, + config: pkggrpc.AgentClientConfig{ + BaseConfig: pkggrpc.BaseConfig{ + URL: testServer.listenAddr, + Timeout: 1, + }, }, err: nil, }, { name: "server not healthy", serverRunning: false, - config: pkggrpc.Config{ - URL: "", - Timeout: 1, + config: pkggrpc.AgentClientConfig{ + BaseConfig: pkggrpc.BaseConfig{ + URL: "", + Timeout: 1, + }, }, err: ErrAgentServiceUnavailable, }, { name: "invalid config, missing AttestationPolicy with aTLS", - config: pkggrpc.Config{ - URL: testServer.listenAddr, - Timeout: 1, + config: pkggrpc.AgentClientConfig{ + BaseConfig: pkggrpc.BaseConfig{ + URL: testServer.listenAddr, + Timeout: 1, + }, AttestedTLS: true, }, err: pkggrpc.ErrAttestationPolicyMissing, diff --git a/pkg/clients/grpc/connect.go b/pkg/clients/grpc/connect.go index c091bfbf..aeb24db9 100644 --- a/pkg/clients/grpc/connect.go +++ b/pkg/clients/grpc/connect.go @@ -50,14 +50,38 @@ var ( errFailedToLoadRootCA = errors.New("failed to load root ca file") ) -type Config struct { - ClientCert string `env:"CLIENT_CERT" envDefault:""` - ClientKey string `env:"CLIENT_KEY" envDefault:""` - ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""` - URL string `env:"URL" envDefault:"localhost:7001"` - Timeout time.Duration `env:"TIMEOUT" envDefault:"60s"` - AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"` - AttestationPolicy string `env:"ATTESTATION_POLICY" envDefault:""` +type ClientConfiguration interface { + GetBaseConfig() BaseConfig +} + +type BaseConfig struct { + URL string `env:"URL" envDefault:"localhost:7001"` + Timeout time.Duration `env:"TIMEOUT" envDefault:"60s"` + ClientCert string `env:"CLIENT_CERT" envDefault:""` + ClientKey string `env:"CLIENT_KEY" envDefault:""` + ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""` +} + +type AgentClientConfig struct { + BaseConfig + AttestationPolicy string `env:"ATTESTATION_POLICY" envDefault:""` + AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"` +} + +type ManagerClientConfig struct { + BaseConfig +} + +func (a BaseConfig) GetBaseConfig() BaseConfig { + return a +} + +func (a AgentClientConfig) GetBaseConfig() BaseConfig { + return a.BaseConfig +} + +func (a ManagerClientConfig) GetBaseConfig() BaseConfig { + return a.BaseConfig } type Client interface { @@ -73,13 +97,13 @@ type Client interface { type client struct { *grpc.ClientConn - cfg Config + cfg ClientConfiguration secure security } var _ Client = (*client)(nil) -func NewClient(cfg Config) (Client, error) { +func NewClient(cfg ClientConfiguration) (Client, error) { conn, secure, err := connect(cfg) if err != nil { return nil, err @@ -120,15 +144,15 @@ func (c *client) Connection() *grpc.ClientConn { } // connect creates new gRPC client and connect to gRPC server. -func connect(cfg Config) (*grpc.ClientConn, security, error) { +func connect(cfg ClientConfiguration) (*grpc.ClientConn, security, error) { opts := []grpc.DialOption{ grpc.WithStatsHandler(otelgrpc.NewClientHandler()), } secure := withoutTLS - tc := insecure.NewCredentials() + var tc credentials.TransportCredentials - if cfg.AttestedTLS { - err := ReadAttestationPolicy(cfg.AttestationPolicy, "eprovider.AttConfigurationSEVSNP) + if agcfg, ok := cfg.(AgentClientConfig); ok && agcfg.AttestedTLS { + err := ReadAttestationPolicy(agcfg.AttestationPolicy, "eprovider.AttConfigurationSEVSNP) if err != nil { return nil, secure, errors.Wrap(fmt.Errorf("failed to read Attestation Policy"), err) } @@ -141,46 +165,60 @@ func connect(cfg Config) (*grpc.ClientConn, security, error) { opts = append(opts, grpc.WithContextDialer(CustomDialer)) secure = withaTLS } else { - if cfg.ServerCAFile != "" { - tlsConfig := &tls.Config{} - - // Loading root ca certificates file - rootCA, err := os.ReadFile(cfg.ServerCAFile) - if err != nil { - return nil, secure, errors.Wrap(errFailedToLoadRootCA, err) - } - if len(rootCA) > 0 { - capool := x509.NewCertPool() - if !capool.AppendCertsFromPEM(rootCA) { - return nil, secure, fmt.Errorf("failed to append root ca to tls.Config") - } - tlsConfig.RootCAs = capool - secure = withTLS - } - - // Loading mTLS certificates file - if cfg.ClientCert != "" || cfg.ClientKey != "" { - certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey) - if err != nil { - return nil, secure, errors.Wrap(errFailedToLoadClientCertKey, err) - } - tlsConfig.Certificates = []tls.Certificate{certificate} - secure = withmTLS - } - - tc = credentials.NewTLS(tlsConfig) + conf := cfg.GetBaseConfig() + transportCreds, err, sec := loadTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey) + if err != nil { + return nil, secure, err } + tc = transportCreds + secure = sec } opts = append(opts, grpc.WithTransportCredentials(tc)) - conn, err := grpc.NewClient(cfg.URL, opts...) + conn, err := grpc.NewClient(cfg.GetBaseConfig().URL, opts...) if err != nil { return nil, secure, errors.Wrap(errGrpcConnect, err) } return conn, secure, nil } +func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, error, security) { + tlsConfig := &tls.Config{} + secure := withoutTLS + tc := insecure.NewCredentials() + + // Load Root CA certificates + if serverCAFile != "" { + rootCA, err := os.ReadFile(serverCAFile) + if err != nil { + return nil, errors.Wrap(errFailedToLoadRootCA, err), secure + } + if len(rootCA) > 0 { + capool := x509.NewCertPool() + if !capool.AppendCertsFromPEM(rootCA) { + return nil, fmt.Errorf("failed to append root ca to tls.Config"), secure + } + tlsConfig.RootCAs = capool + secure = withTLS + tc = credentials.NewTLS(tlsConfig) + } + } + + // Load mTLS certificates + if clientCert != "" || clientKey != "" { + certificate, err := tls.LoadX509KeyPair(clientCert, clientKey) + if err != nil { + return nil, errors.Wrap(errFailedToLoadClientCertKey, err), secure + } + tlsConfig.Certificates = []tls.Certificate{certificate} + secure = withmTLS + tc = credentials.NewTLS(tlsConfig) + } + + return tc, nil, secure +} + func ReadAttestationPolicy(manifestPath string, attestationConfiguration *check.Config) error { if manifestPath != "" { manifest, err := os.ReadFile(manifestPath) diff --git a/pkg/clients/grpc/connect_test.go b/pkg/clients/grpc/connect_test.go index fd3ed29e..863c9654 100644 --- a/pkg/clients/grpc/connect_test.go +++ b/pkg/clients/grpc/connect_test.go @@ -11,6 +11,7 @@ import ( "fmt" "math/big" "os" + "strings" "testing" "time" @@ -31,14 +32,15 @@ func TestNewClient(t *testing.T) { }) tests := []struct { - name string - cfg Config - wantErr bool - err error + name string + cfg BaseConfig + agentCfg AgentClientConfig + wantErr bool + err error }{ { name: "Success without TLS", - cfg: Config{ + cfg: BaseConfig{ URL: "localhost:7001", }, wantErr: false, @@ -46,7 +48,7 @@ func TestNewClient(t *testing.T) { }, { name: "Success with TLS", - cfg: Config{ + cfg: BaseConfig{ URL: "localhost:7001", ServerCAFile: caCertFile, }, @@ -55,7 +57,7 @@ func TestNewClient(t *testing.T) { }, { name: "Success with mTLS", - cfg: Config{ + cfg: BaseConfig{ URL: "localhost:7001", ServerCAFile: caCertFile, ClientCert: clientCertFile, @@ -64,9 +66,52 @@ func TestNewClient(t *testing.T) { wantErr: false, err: nil, }, + { + name: "Success agent client with mTLS", + agentCfg: AgentClientConfig{ + BaseConfig: BaseConfig{ + URL: "localhost:7001", + ServerCAFile: caCertFile, + ClientCert: clientCertFile, + ClientKey: clientKeyFile, + }, + }, + wantErr: false, + err: nil, + }, + { + name: "Success agent client with aTLS", + agentCfg: AgentClientConfig{ + BaseConfig: BaseConfig{ + URL: "localhost:7001", + ServerCAFile: caCertFile, + ClientCert: clientCertFile, + ClientKey: clientKeyFile, + }, + AttestedTLS: true, + AttestationPolicy: "../../../scripts/attestation_policy/attestation_policy.json", + }, + wantErr: false, + err: nil, + }, + { + name: "Failed agent client with aTLS", + agentCfg: AgentClientConfig{ + BaseConfig: BaseConfig{ + URL: "localhost:7001", + ServerCAFile: caCertFile, + ClientCert: clientCertFile, + ClientKey: clientKeyFile, + }, + AttestedTLS: true, + AttestationPolicy: "no such file", + }, + wantErr: true, + err: fmt.Errorf("failed to read Attestation Policy"), + }, { name: "Fail with invalid ServerCAFile", - cfg: Config{ + cfg: BaseConfig{ URL: "localhost:7001", ServerCAFile: "nonexistent.pem", }, @@ -75,7 +120,7 @@ func TestNewClient(t *testing.T) { }, { name: "Fail with invalid ClientCert", - cfg: Config{ + cfg: BaseConfig{ URL: "localhost:7001", ServerCAFile: caCertFile, ClientCert: "nonexistent.pem", @@ -86,7 +131,7 @@ func TestNewClient(t *testing.T) { }, { name: "Fail with invalid ClientKey", - cfg: Config{ + cfg: BaseConfig{ URL: "localhost:7001", ServerCAFile: caCertFile, ClientCert: clientCertFile, @@ -99,7 +144,12 @@ func TestNewClient(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client, err := NewClient(tt.cfg) + var client Client + if strings.Contains(tt.name, "agent client") { + client, err = NewClient(tt.agentCfg) + } else { + client, err = NewClient(tt.cfg) + } assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err)) if tt.wantErr { assert.Error(t, err) diff --git a/pkg/clients/grpc/manager/manager.go b/pkg/clients/grpc/manager/manager.go index 3f937387..736796bc 100644 --- a/pkg/clients/grpc/manager/manager.go +++ b/pkg/clients/grpc/manager/manager.go @@ -8,7 +8,7 @@ import ( ) // NewManagerClient creates new manager gRPC client instance. -func NewManagerClient(cfg grpc.Config) (grpc.Client, manager.ManagerServiceClient, error) { +func NewManagerClient(cfg grpc.ManagerClientConfig) (grpc.Client, manager.ManagerServiceClient, error) { client, err := grpc.NewClient(cfg) if err != nil { return nil, nil, err diff --git a/pkg/clients/grpc/manager/manager_test.go b/pkg/clients/grpc/manager/manager_test.go index 2ee2592f..49ae5627 100644 --- a/pkg/clients/grpc/manager/manager_test.go +++ b/pkg/clients/grpc/manager/manager_test.go @@ -13,21 +13,18 @@ import ( func TestNewManagerClient(t *testing.T) { tests := []struct { name string - cfg grpc.Config + cfg grpc.ManagerClientConfig err error }{ { name: "Valid config", - cfg: grpc.Config{ - URL: "localhost:7001", + cfg: grpc.ManagerClientConfig{ + BaseConfig: grpc.BaseConfig{ + URL: "localhost:7001", + }, }, err: nil, }, - { - name: "invalid config, missing AttestationPolicy with aTLS", - cfg: grpc.Config{AttestedTLS: true}, - err: grpc.ErrAttestationPolicyMissing, - }, } for _, tt := range tests { diff --git a/test/computations/main.go b/test/computations/main.go index b2e3dda2..a45fa7ee 100644 --- a/test/computations/main.go +++ b/test/computations/main.go @@ -130,7 +130,11 @@ func main() { reflection.Register(srv) manager.RegisterManagerServiceServer(srv, managergrpc.NewServer(incomingChan, &svc{logger: logger})) } - grpcServerConfig := server.Config{Port: defaultPort} + grpcServerConfig := server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Port: defaultPort, + }, + } if err := env.ParseWithOptions(&grpcServerConfig, env.Options{}); err != nil { logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)) return From 70a8ac549f97e22f7ebc074559d1b19b3f65e657 Mon Sep 17 00:00:00 2001 From: WashingtonKK Date: Wed, 4 Dec 2024 19:24:53 +0300 Subject: [PATCH 3/4] remove redundant code Signed-off-by: WashingtonKK --- internal/server/grpc/grpc_test.go | 70 ------------------------------- 1 file changed, 70 deletions(-) diff --git a/internal/server/grpc/grpc_test.go b/internal/server/grpc/grpc_test.go index d9a8f9e1..85ff1a72 100644 --- a/internal/server/grpc/grpc_test.go +++ b/internal/server/grpc/grpc_test.go @@ -171,76 +171,6 @@ func TestServerStartWithMTLS(t *testing.T) { assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS") } -func TestServerStartWithmTLSFile(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - - cert, key, err := generateSelfSignedCert() - assert.NoError(t, err) - - certFile, err := os.CreateTemp("", "cert*.pem") - assert.NoError(t, err) - - keyFile, err := os.CreateTemp("", "key*.pem") - assert.NoError(t, err) - - t.Cleanup(func() { - os.Remove(certFile.Name()) - os.Remove(keyFile.Name()) - }) - - _, err = certFile.Write(cert) - assert.NoError(t, err) - - _, err = keyFile.Write(key) - assert.NoError(t, err) - - err = certFile.Close() - assert.NoError(t, err) - err = keyFile.Close() - assert.NoError(t, err) - - config := server.AgentConfig{ - ServerConfig: server.ServerConfig{ - BaseConfig: server.BaseConfig{ - Host: "localhost", - Port: "0", - CertFile: certFile.Name(), - KeyFile: keyFile.Name(), - ServerCAFile: certFile.Name(), - ClientCAFile: certFile.Name(), - }, - }, - } - - logBuffer := &ThreadSafeBuffer{} - logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - qp := new(mocks.QuoteProvider) - authSvc := new(authmocks.Authenticator) - - srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc) - - var wg sync.WaitGroup - wg.Add(1) - - go func() { - wg.Done() - err := srv.Start() - assert.NoError(t, err) - }() - - wg.Wait() - - time.Sleep(200 * time.Millisecond) - - cancel() - - time.Sleep(200 * time.Millisecond) - - logContent := logBuffer.String() - fmt.Println(logContent) - assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS") -} - func TestServerStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) From 58dc3ef8c3f85b34e4edcbbdb03b1dde5d4b71f1 Mon Sep 17 00:00:00 2001 From: WashingtonKK Date: Wed, 4 Dec 2024 20:04:45 +0300 Subject: [PATCH 4/4] fix test Signed-off-by: WashingtonKK --- internal/server/grpc/grpc_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/server/grpc/grpc_test.go b/internal/server/grpc/grpc_test.go index 85ff1a72..b5bcbaeb 100644 --- a/internal/server/grpc/grpc_test.go +++ b/internal/server/grpc/grpc_test.go @@ -56,7 +56,7 @@ func TestNew(t *testing.T) { assert.IsType(t, &Server{}, srv) } -func TestServerStartWithTLS(t *testing.T) { +func TestServerStartWithTLSFile(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cert, key, err := generateSelfSignedCert() @@ -124,7 +124,7 @@ func TestServerStartWithTLS(t *testing.T) { assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS") } -func TestServerStartWithMTLS(t *testing.T) { +func TestServerStartWithmTLSFile(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) caCertFile, clientCertFile, clientKeyFile, err := createCertificatesFiles()