diff --git a/bootstrapper/cmd/bootstrapper/run.go b/bootstrapper/cmd/bootstrapper/run.go index 733444beec..95bd46b060 100644 --- a/bootstrapper/cmd/bootstrapper/run.go +++ b/bootstrapper/cmd/bootstrapper/run.go @@ -10,8 +10,11 @@ import ( "context" "fmt" "log/slog" + "log/syslog" "net" - "os" + "sync" + "syscall" + "time" "github.com/edgelesssys/constellation/v2/bootstrapper/internal/clean" "github.com/edgelesssys/constellation/v2/bootstrapper/internal/diskencryption" @@ -32,7 +35,8 @@ func run(issuer atls.Issuer, openDevice vtpm.TPMOpenFunc, fileHandler file.Handl ) { log.With(slog.String("version", constants.BinaryVersion().String())).Info("Starting bootstrapper") - uuid, err := getDiskUUID() + disk := diskencryption.New() + uuid, err := getDiskUUID(disk) if err != nil { log.With(slog.Any("error", err)).Error("Failed to get disk UUID") } else { @@ -42,43 +46,58 @@ func run(issuer atls.Issuer, openDevice vtpm.TPMOpenFunc, fileHandler file.Handl nodeBootstrapped, err := initialize.IsNodeBootstrapped(openDevice) if err != nil { log.With(slog.Any("error", err)).Error("Failed to check if node was previously bootstrapped") - os.Exit(1) + reboot(fmt.Errorf("checking if node was previously bootstrapped: %w", err)) } if nodeBootstrapped { if err := kube.StartKubelet(); err != nil { log.With(slog.Any("error", err)).Error("Failed to restart kubelet") - os.Exit(1) + reboot(fmt.Errorf("restarting kubelet: %w", err)) } return } nodeLock := nodelock.New(openDevice) - initServer, err := initserver.New(context.Background(), nodeLock, kube, issuer, fileHandler, metadata, log) + initServer, err := initserver.New(context.Background(), nodeLock, kube, issuer, disk, fileHandler, metadata, log) if err != nil { log.With(slog.Any("error", err)).Error("Failed to create init server") - os.Exit(1) + reboot(fmt.Errorf("creating init server: %w", err)) } dialer := dialer.New(issuer, nil, &net.Dialer{}) - joinClient := joinclient.New(nodeLock, dialer, kube, metadata, log) + joinClient := joinclient.New(nodeLock, dialer, kube, metadata, disk, log) cleaner := clean.New().With(initServer).With(joinClient) go cleaner.Start() defer cleaner.Done() - joinClient.Start(cleaner) + var wg sync.WaitGroup - if err := initServer.Serve(bindIP, bindPort, cleaner); err != nil { - log.With(slog.Any("error", err)).Error("Failed to serve init server") - os.Exit(1) - } + wg.Add(1) + go func() { + defer wg.Done() + if err := joinClient.Start(cleaner); err != nil { + log.With(slog.Any("error", err)).Error("Failed to join cluster") + markDiskForReset(disk) + reboot(fmt.Errorf("joining cluster: %w", err)) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + if err := initServer.Serve(bindIP, bindPort, cleaner); err != nil { + log.With(slog.Any("error", err)).Error("Failed to serve init server") + markDiskForReset(disk) + reboot(fmt.Errorf("serving init server: %w", err)) + } + }() + wg.Wait() log.Info("bootstrapper done") } -func getDiskUUID() (string, error) { - disk := diskencryption.New() +func getDiskUUID(disk *diskencryption.DiskEncryption) (string, error) { free, err := disk.Open() if err != nil { return "", err @@ -87,6 +106,36 @@ func getDiskUUID() (string, error) { return disk.UUID() } +// markDiskForReset sets a token in the cryptsetup header of the disk to indicate the disk should be reset on next boot. +// This is used to reset all state of a node in case the bootstrapper encountered a non recoverable error +// after the node successfully retrieved a join ticket from the JoinService. 
+// As setting this token is safe as long as we are certain we don't need the data on the disk anymore, we call this +// unconditionally when either the JoinClient or the InitServer encounter an error. +// We don't call it before that, as the node may be restarting after a previous, successful bootstrapping, +// and now encountered a transient error on rejoining the cluster. Wiping the disk now would delete existing data. +func markDiskForReset(disk *diskencryption.DiskEncryption) { + free, err := disk.Open() + if err != nil { + return + } + defer free() + _ = disk.MarkDiskForReset() +} + +// reboot writes an error message to the system log and reboots the system. +// We call this instead of os.Exit() since failures in the bootstrapper usually require a node reset. +func reboot(e error) { + syslogWriter, err := syslog.New(syslog.LOG_EMERG|syslog.LOG_KERN, "bootstrapper") + if err != nil { + _ = syscall.Reboot(syscall.LINUX_REBOOT_CMD_RESTART) + } + _ = syslogWriter.Err(e.Error()) + _ = syslogWriter.Emerg("bootstrapper has encountered a non recoverable error. Rebooting...") + time.Sleep(time.Minute) // sleep to allow the message to be written to syslog and seen by the user + + _ = syscall.Reboot(syscall.LINUX_REBOOT_CMD_RESTART) +} + type clusterInitJoiner interface { joinclient.ClusterJoiner initserver.ClusterInitializer diff --git a/bootstrapper/internal/diskencryption/diskencryption.go b/bootstrapper/internal/diskencryption/diskencryption.go index eaf97e7ab4..e8fbcb4a20 100644 --- a/bootstrapper/internal/diskencryption/diskencryption.go +++ b/bootstrapper/internal/diskencryption/diskencryption.go @@ -60,6 +60,11 @@ func (c *DiskEncryption) UpdatePassphrase(passphrase string) error { return c.device.SetConstellationStateDiskToken(cryptsetup.SetDiskInitialized) } +// MarkDiskForReset marks the state disk as not initialized so it may be wiped (reset) on reboot. +func (c *DiskEncryption) MarkDiskForReset() error { + return c.device.SetConstellationStateDiskToken(cryptsetup.SetDiskNotInitialized) +} + // getInitialPassphrase retrieves the initial passphrase used on first boot. func (c *DiskEncryption) getInitialPassphrase() (string, error) { passphrase, err := afero.ReadFile(c.fs, initialKeyPath) diff --git a/bootstrapper/internal/initserver/BUILD.bazel b/bootstrapper/internal/initserver/BUILD.bazel index 009bb0594b..b1d5e66ba1 100644 --- a/bootstrapper/internal/initserver/BUILD.bazel +++ b/bootstrapper/internal/initserver/BUILD.bazel @@ -8,7 +8,6 @@ go_library( visibility = ["//bootstrapper:__subpackages__"], deps = [ "//bootstrapper/initproto", - "//bootstrapper/internal/diskencryption", "//bootstrapper/internal/journald", "//internal/atls", "//internal/attestation", diff --git a/bootstrapper/internal/initserver/initserver.go b/bootstrapper/internal/initserver/initserver.go index ff2e5e975f..a38bdbc8d9 100644 --- a/bootstrapper/internal/initserver/initserver.go +++ b/bootstrapper/internal/initserver/initserver.go @@ -30,7 +30,6 @@ import ( "time" "github.com/edgelesssys/constellation/v2/bootstrapper/initproto" - "github.com/edgelesssys/constellation/v2/bootstrapper/internal/diskencryption" "github.com/edgelesssys/constellation/v2/bootstrapper/internal/journald" "github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/attestation" @@ -65,6 +64,7 @@ type Server struct { shutdownLock sync.RWMutex initSecretHash []byte + initFailure error kmsURI string @@ -76,7 +76,10 @@ type Server struct { } // New creates a new initialization server. 
-func New(ctx context.Context, lock locker, kube ClusterInitializer, issuer atls.Issuer, fh file.Handler, metadata MetadataAPI, log *slog.Logger) (*Server, error) { +func New( + ctx context.Context, lock locker, kube ClusterInitializer, issuer atls.Issuer, + disk encryptedDisk, fh file.Handler, metadata MetadataAPI, log *slog.Logger, +) (*Server, error) { log = log.WithGroup("initServer") initSecretHash, err := metadata.InitSecretHash(ctx) @@ -94,7 +97,7 @@ func New(ctx context.Context, lock locker, kube ClusterInitializer, issuer atls. server := &Server{ nodeLock: lock, - disk: diskencryption.New(), + disk: disk, initializer: kube, fileHandler: fh, issuer: issuer, @@ -123,11 +126,20 @@ func (s *Server) Serve(ip, port string, cleaner cleaner) error { } s.log.Info("Starting") - return s.grpcServer.Serve(lis) + err = s.grpcServer.Serve(lis) + + // If Init failed, we mark the disk for reset, so the node can restart the process + // In this case we don't care about any potential errors from the grpc server + if s.initFailure != nil { + s.log.Error("Fatal error during Init request", "error", s.initFailure) + return err + } + + return err } // Init initializes the cluster. -func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServer) (err error) { +func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServer) (retErr error) { // Acquire lock to prevent shutdown while Init is still running s.shutdownLock.RLock() defer s.shutdownLock.RUnlock() @@ -188,6 +200,9 @@ func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServe // since we are bootstrapping a new one. // Any errors following this call will result in a failed node that may not join any cluster. s.cleaner.Clean() + defer func() { + s.initFailure = retErr + }() if err := s.setupDisk(stream.Context(), cloudKms); err != nil { if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "setting up disk: %s", err)); e != nil { diff --git a/bootstrapper/internal/initserver/initserver_test.go b/bootstrapper/internal/initserver/initserver_test.go index 77a0b0817f..84d0316d7c 100644 --- a/bootstrapper/internal/initserver/initserver_test.go +++ b/bootstrapper/internal/initserver/initserver_test.go @@ -67,7 +67,10 @@ func TestNew(t *testing.T) { t.Run(name, func(t *testing.T) { assert := assert.New(t) - server, err := New(context.TODO(), newFakeLock(), &stubClusterInitializer{}, atls.NewFakeIssuer(variant.Dummy{}), fh, &tc.metadata, logger.NewTest(t)) + server, err := New( + context.TODO(), newFakeLock(), &stubClusterInitializer{}, atls.NewFakeIssuer(variant.Dummy{}), + &stubDisk{}, fh, &tc.metadata, logger.NewTest(t), + ) if tc.wantErr { assert.Error(err) return @@ -381,6 +384,10 @@ func (d *fakeDisk) UpdatePassphrase(passphrase string) error { return nil } +func (d *fakeDisk) MarkDiskForReset() error { + return nil +} + type stubDisk struct { openErr error uuid string @@ -402,6 +409,10 @@ func (d *stubDisk) UpdatePassphrase(string) error { return d.updatePassphraseErr } +func (d *stubDisk) MarkDiskForReset() error { + return nil +} + type stubClusterInitializer struct { initClusterKubeconfig []byte initClusterErr error diff --git a/bootstrapper/internal/joinclient/BUILD.bazel b/bootstrapper/internal/joinclient/BUILD.bazel index 3b8bf70b7f..687ffd250c 100644 --- a/bootstrapper/internal/joinclient/BUILD.bazel +++ b/bootstrapper/internal/joinclient/BUILD.bazel @@ -8,7 +8,6 @@ go_library( visibility = ["//bootstrapper:__subpackages__"], deps = [ 
"//bootstrapper/internal/certificate", - "//bootstrapper/internal/diskencryption", "//internal/attestation", "//internal/cloud/metadata", "//internal/constants", diff --git a/bootstrapper/internal/joinclient/joinclient.go b/bootstrapper/internal/joinclient/joinclient.go index 8f44fa1153..3e29443254 100644 --- a/bootstrapper/internal/joinclient/joinclient.go +++ b/bootstrapper/internal/joinclient/joinclient.go @@ -25,11 +25,9 @@ import ( "net" "path/filepath" "strconv" - "sync" "time" "github.com/edgelesssys/constellation/v2/bootstrapper/internal/certificate" - "github.com/edgelesssys/constellation/v2/bootstrapper/internal/diskencryption" "github.com/edgelesssys/constellation/v2/internal/attestation" "github.com/edgelesssys/constellation/v2/internal/cloud/metadata" "github.com/edgelesssys/constellation/v2/internal/constants" @@ -69,21 +67,19 @@ type JoinClient struct { dialer grpcDialer joiner ClusterJoiner - cleaner cleaner metadataAPI MetadataAPI log *slog.Logger - mux sync.Mutex stopC chan struct{} stopDone chan struct{} } // New creates a new JoinClient. -func New(lock locker, dial grpcDialer, joiner ClusterJoiner, meta MetadataAPI, log *slog.Logger) *JoinClient { +func New(lock locker, dial grpcDialer, joiner ClusterJoiner, meta MetadataAPI, disk encryptedDisk, log *slog.Logger) *JoinClient { return &JoinClient{ nodeLock: lock, - disk: diskencryption.New(), + disk: disk, fileHandler: file.NewHandler(afero.NewOsFs()), timeout: timeout, joinTimeout: joinTimeout, @@ -93,99 +89,83 @@ func New(lock locker, dial grpcDialer, joiner ClusterJoiner, meta MetadataAPI, l joiner: joiner, metadataAPI: meta, log: log.WithGroup("join-client"), + + stopC: make(chan struct{}, 1), + stopDone: make(chan struct{}, 1), } } // Start starts the client routine. The client will make the needed API calls to join // the cluster with the role it receives from the metadata API. // After receiving the needed information, the node will join the cluster. -// Multiple calls of start on the same client won't start a second routine if there is -// already a routine running. 
-func (c *JoinClient) Start(cleaner cleaner) { - c.mux.Lock() - defer c.mux.Unlock() +func (c *JoinClient) Start(cleaner cleaner) error { + c.log.Info("Starting") + ticker := c.clock.NewTicker(c.interval) + defer ticker.Stop() + defer func() { c.stopDone <- struct{}{} }() + defer c.log.Info("Client stopped") - if c.stopC != nil { // daemon already running - return + diskUUID, err := c.getDiskUUID() + if err != nil { + c.log.With(slog.Any("error", err)).Error("Failed to get disk UUID") + return err } + c.diskUUID = diskUUID - c.log.Info("Starting") - c.stopC = make(chan struct{}, 1) - c.stopDone = make(chan struct{}, 1) - c.cleaner = cleaner + for { + err := c.getNodeMetadata() + if err == nil { + c.log.With(slog.String("role", c.role.String()), slog.String("name", c.nodeName)).Info("Received own instance metadata") + break + } + c.log.With(slog.Any("error", err)).Error("Failed to retrieve instance metadata") - ticker := c.clock.NewTicker(c.interval) - go func() { - defer ticker.Stop() - defer func() { c.stopDone <- struct{}{} }() - defer c.log.Info("Client stopped") - - diskUUID, err := c.getDiskUUID() - if err != nil { - c.log.With(slog.Any("error", err)).Error("Failed to get disk UUID") - return + c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping") + select { + case <-c.stopC: + return nil + case <-ticker.C(): } - c.diskUUID = diskUUID - - for { - err := c.getNodeMetadata() - if err == nil { - c.log.With(slog.String("role", c.role.String()), slog.String("name", c.nodeName)).Info("Received own instance metadata") - break - } - c.log.With(slog.Any("error", err)).Error("Failed to retrieve instance metadata") - - c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping") - select { - case <-c.stopC: - return - case <-ticker.C(): - } + } + + var ticket *joinproto.IssueJoinTicketResponse + var kubeletKey []byte + + for { + ticket, kubeletKey, err = c.tryJoinWithAvailableServices() + if err == nil { + c.log.Info("Successfully retrieved join ticket, starting Kubernetes node") + break } + c.log.With(slog.Any("error", err)).Warn("Join failed for all available endpoints") - for { - err := c.tryJoinWithAvailableServices() - if err == nil { - c.log.Info("Joined successfully. Client is shutting down") - return - } else if isUnrecoverable(err) { - c.log.With(slog.Any("error", err)).Error("Unrecoverable error occurred") - // TODO(burgerdev): this should eventually lead to a full node reset - return - } - c.log.With(slog.Any("error", err)).Warn("Join failed for all available endpoints") - - c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping") - select { - case <-c.stopC: - return - case <-ticker.C(): - } + c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping") + select { + case <-c.stopC: + return nil + case <-ticker.C(): } - }() + } + + if err := c.startNodeAndJoin(ticket, kubeletKey, cleaner); err != nil { + c.log.With(slog.Any("error", err)).Error("Failed to start node and join cluster") + return err + } + + return nil } // Stop stops the client and blocks until the client's routine is stopped. 
func (c *JoinClient) Stop() { - c.mux.Lock() - defer c.mux.Unlock() - - if c.stopC == nil { // daemon not running - return - } - c.log.Info("Stopping") c.stopC <- struct{}{} <-c.stopDone - c.stopC = nil - c.stopDone = nil - c.log.Info("Stopped") } -func (c *JoinClient) tryJoinWithAvailableServices() error { +func (c *JoinClient) tryJoinWithAvailableServices() (ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte, err error) { ctx, cancel := c.timeoutCtx() defer cancel() @@ -193,46 +173,46 @@ func (c *JoinClient) tryJoinWithAvailableServices() error { endpoint, _, err := c.metadataAPI.GetLoadBalancerEndpoint(ctx) if err != nil { - return fmt.Errorf("failed to get load balancer endpoint: %w", err) + return nil, nil, fmt.Errorf("failed to get load balancer endpoint: %w", err) } endpoints = append(endpoints, endpoint) ips, err := c.getControlPlaneIPs(ctx) if err != nil { - return fmt.Errorf("failed to get control plane IPs: %w", err) + return nil, nil, fmt.Errorf("failed to get control plane IPs: %w", err) } endpoints = append(endpoints, ips...) if len(endpoints) == 0 { - return errors.New("no control plane IPs found") + return nil, nil, errors.New("no control plane IPs found") } + var joinErrs error for _, endpoint := range endpoints { - err = c.join(net.JoinHostPort(endpoint, strconv.Itoa(constants.JoinServiceNodePort))) + ticket, kubeletKey, err := c.requestJoinTicket(net.JoinHostPort(endpoint, strconv.Itoa(constants.JoinServiceNodePort))) if err == nil { - return nil - } - if isUnrecoverable(err) { - return err + return ticket, kubeletKey, nil } + + joinErrs = errors.Join(joinErrs, err) } - return err + return nil, nil, fmt.Errorf("trying to join on all endpoints %v: %w", endpoints, joinErrs) } -func (c *JoinClient) join(serviceEndpoint string) error { +func (c *JoinClient) requestJoinTicket(serviceEndpoint string) (ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte, err error) { ctx, cancel := c.timeoutCtx() defer cancel() certificateRequest, kubeletKey, err := certificate.GetKubeletCertificateRequest(c.nodeName, c.validIPs) if err != nil { - return err + return nil, nil, err } conn, err := c.dialer.Dial(ctx, serviceEndpoint) if err != nil { c.log.With(slog.String("endpoint", serviceEndpoint), slog.Any("error", err)).Error("Join service unreachable") - return fmt.Errorf("dialing join service endpoint: %w", err) + return nil, nil, fmt.Errorf("dialing join service endpoint: %w", err) } defer conn.Close() @@ -242,26 +222,19 @@ func (c *JoinClient) join(serviceEndpoint string) error { CertificateRequest: certificateRequest, IsControlPlane: c.role == role.ControlPlane, } - ticket, err := protoClient.IssueJoinTicket(ctx, req) + ticket, err = protoClient.IssueJoinTicket(ctx, req) if err != nil { c.log.With(slog.String("endpoint", serviceEndpoint), slog.Any("error", err)).Error("Issuing join ticket failed") - return fmt.Errorf("issuing join ticket: %w", err) + return nil, nil, fmt.Errorf("issuing join ticket: %w", err) } - return c.startNodeAndJoin(ticket, kubeletKey) + return ticket, kubeletKey, err } -func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte) (retErr error) { +func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte, cleaner cleaner) error { ctx, cancel := context.WithTimeout(context.Background(), c.joinTimeout) defer cancel() - // If an error occurs in this func, the client cannot continue. 
- defer func() { - if retErr != nil { - retErr = unrecoverableError{retErr} - } - }() - clusterID, err := attestation.DeriveClusterID(ticket.MeasurementSecret, ticket.MeasurementSalt) if err != nil { return err @@ -276,10 +249,11 @@ func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse, // There is already a cluster initialization in progress on // this node, so there is no need to also join the cluster, // as the initializing node is automatically part of the cluster. - return errors.New("node is already being initialized") + c.log.Info("Node is already being initialized. Aborting join process.") + return nil } - c.cleaner.Clean() + cleaner.Clean() if err := c.updateDiskPassphrase(string(ticket.StateDiskKey)); err != nil { return fmt.Errorf("updating disk passphrase: %w", err) @@ -313,11 +287,12 @@ func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse, // We currently cannot recover from any failure in this function. Joining the k8s cluster // sometimes fails transiently, and we don't want to brick the node because of that. - for i := 0; i < 3; i++ { + for i := range 3 { err = c.joiner.JoinCluster(ctx, btd, c.role, ticket.KubernetesComponents, c.log) - if err != nil { - c.log.Error("failed to join k8s cluster", "role", c.role, "attempt", i, "error", err) + if err == nil { + break } + c.log.Error("failed to join k8s cluster", "role", c.role, "attempt", i, "error", err) } if err != nil { return fmt.Errorf("joining Kubernetes cluster: %w", err) @@ -412,13 +387,6 @@ func (c *JoinClient) timeoutCtx() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), c.timeout) } -type unrecoverableError struct{ error } - -func isUnrecoverable(err error) bool { - _, ok := err.(unrecoverableError) - return ok -} - type grpcDialer interface { Dial(ctx context.Context, target string) (*grpc.ClientConn, error) } diff --git a/bootstrapper/internal/joinclient/joinclient_test.go b/bootstrapper/internal/joinclient/joinclient_test.go index d22ed4fb97..a93ed4b3fe 100644 --- a/bootstrapper/internal/joinclient/joinclient_test.go +++ b/bootstrapper/internal/joinclient/joinclient_test.go @@ -8,7 +8,6 @@ package joinclient import ( "context" - "errors" "log/slog" "net" "strconv" @@ -40,7 +39,6 @@ func TestMain(m *testing.M) { } func TestClient(t *testing.T) { - someErr := errors.New("failed") lockedLock := newFakeLock() aqcuiredLock, lockErr := lockedLock.TryLockOnce(nil) require.True(t, aqcuiredLock) @@ -67,9 +65,9 @@ func TestClient(t *testing.T) { "on worker: metadata self: errors occur": { role: role.Worker, apiAnswers: []any{ - selfAnswer{err: someErr}, - selfAnswer{err: someErr}, - selfAnswer{err: someErr}, + selfAnswer{err: assert.AnError}, + selfAnswer{err: assert.AnError}, + selfAnswer{err: assert.AnError}, selfAnswer{instance: workerSelf}, listAnswer{instances: peers}, issueJoinTicketAnswer{}, @@ -100,9 +98,9 @@ func TestClient(t *testing.T) { role: role.Worker, apiAnswers: []any{ selfAnswer{instance: workerSelf}, - listAnswer{err: someErr}, - listAnswer{err: someErr}, - listAnswer{err: someErr}, + listAnswer{err: assert.AnError}, + listAnswer{err: assert.AnError}, + listAnswer{err: assert.AnError}, listAnswer{instances: peers}, issueJoinTicketAnswer{}, }, @@ -133,9 +131,9 @@ func TestClient(t *testing.T) { apiAnswers: []any{ selfAnswer{instance: workerSelf}, listAnswer{instances: peers}, - issueJoinTicketAnswer{err: someErr}, + issueJoinTicketAnswer{err: assert.AnError}, listAnswer{instances: peers}, - 
issueJoinTicketAnswer{err: someErr}, + issueJoinTicketAnswer{err: assert.AnError}, listAnswer{instances: peers}, issueJoinTicketAnswer{}, }, @@ -150,9 +148,9 @@ func TestClient(t *testing.T) { apiAnswers: []any{ selfAnswer{instance: controlSelf}, listAnswer{instances: peers}, - issueJoinTicketAnswer{err: someErr}, + issueJoinTicketAnswer{err: assert.AnError}, listAnswer{instances: peers}, - issueJoinTicketAnswer{err: someErr}, + issueJoinTicketAnswer{err: assert.AnError}, listAnswer{instances: peers}, issueJoinTicketAnswer{}, }, @@ -169,7 +167,7 @@ func TestClient(t *testing.T) { listAnswer{instances: peers}, issueJoinTicketAnswer{}, }, - clusterJoiner: &stubClusterJoiner{numBadCalls: -1, joinClusterErr: someErr}, + clusterJoiner: &stubClusterJoiner{numBadCalls: -1, joinClusterErr: assert.AnError}, nodeLock: newFakeLock(), disk: &stubDisk{}, wantJoin: true, @@ -182,7 +180,7 @@ func TestClient(t *testing.T) { listAnswer{instances: peers}, issueJoinTicketAnswer{}, }, - clusterJoiner: &stubClusterJoiner{numBadCalls: 1, joinClusterErr: someErr}, + clusterJoiner: &stubClusterJoiner{numBadCalls: 1, joinClusterErr: assert.AnError}, nodeLock: newFakeLock(), disk: &stubDisk{}, wantJoin: true, @@ -205,13 +203,13 @@ func TestClient(t *testing.T) { role: role.ControlPlane, clusterJoiner: &stubClusterJoiner{}, nodeLock: newFakeLock(), - disk: &stubDisk{openErr: someErr}, + disk: &stubDisk{openErr: assert.AnError}, }, "on control plane: disk uuid fails": { role: role.ControlPlane, clusterJoiner: &stubClusterJoiner{}, nodeLock: newFakeLock(), - disk: &stubDisk{uuidErr: someErr}, + disk: &stubDisk{uuidErr: assert.AnError}, }, } @@ -237,6 +235,9 @@ func TestClient(t *testing.T) { metadataAPI: metadataAPI, clock: clock, log: logger.NewTest(t), + + stopC: make(chan struct{}, 1), + stopDone: make(chan struct{}, 1), } serverCreds := atlscredentials.New(nil, nil) @@ -248,7 +249,7 @@ func TestClient(t *testing.T) { go joinServer.Serve(listener) defer joinServer.GracefulStop() - client.Start(stubCleaner{}) + go func() { _ = client.Start(stubCleaner{}) }() for _, a := range tc.apiAnswers { switch a := a.(type) { @@ -281,78 +282,6 @@ func TestClient(t *testing.T) { } } -func TestClientConcurrentStartStop(t *testing.T) { - netDialer := testdialer.NewBufconnDialer() - dialer := dialer.New(nil, nil, netDialer) - client := &JoinClient{ - nodeLock: newFakeLock(), - timeout: 30 * time.Second, - interval: 30 * time.Second, - dialer: dialer, - disk: &stubDisk{}, - joiner: &stubClusterJoiner{}, - fileHandler: file.NewHandler(afero.NewMemMapFs()), - metadataAPI: &stubRepeaterMetadataAPI{}, - clock: testclock.NewFakeClock(time.Now()), - log: logger.NewTest(t), - } - - wg := sync.WaitGroup{} - - start := func() { - defer wg.Done() - client.Start(stubCleaner{}) - } - - stop := func() { - defer wg.Done() - client.Stop() - } - - wg.Add(10) - go stop() - go start() - go start() - go stop() - go stop() - go start() - go start() - go stop() - go stop() - go start() - wg.Wait() - - client.Stop() -} - -func TestIsUnrecoverable(t *testing.T) { - assert := assert.New(t) - - some := errors.New("failed") - unrec := unrecoverableError{some} - assert.True(isUnrecoverable(unrec)) - assert.False(isUnrecoverable(some)) -} - -type stubRepeaterMetadataAPI struct { - selfInstance metadata.InstanceMetadata - selfErr error - listInstances []metadata.InstanceMetadata - listErr error -} - -func (s *stubRepeaterMetadataAPI) Self(_ context.Context) (metadata.InstanceMetadata, error) { - return s.selfInstance, s.selfErr -} - -func (s 
*stubRepeaterMetadataAPI) List(_ context.Context) ([]metadata.InstanceMetadata, error) { - return s.listInstances, s.listErr -} - -func (s *stubRepeaterMetadataAPI) GetLoadBalancerEndpoint(_ context.Context) (string, string, error) { - return "", "", nil -} - type stubMetadataAPI struct { selfAnswerC chan selfAnswer listAnswerC chan listAnswer @@ -451,6 +380,10 @@ func (d *stubDisk) UpdatePassphrase(string) error { return d.updatePassphraseErr } +func (d *stubDisk) MarkDiskForReset() error { + return nil +} + type stubCleaner struct{} func (c stubCleaner) Clean() {} diff --git a/csi/test/mount_integration_test.go b/csi/test/mount_integration_test.go index 986636bf1a..36e9f7b15b 100644 --- a/csi/test/mount_integration_test.go +++ b/csi/test/mount_integration_test.go @@ -31,7 +31,7 @@ const ( deviceName string = "testDeviceName" ) -var toolsEnvs []string = []string{"CP", "DD", "RM", "FSCK_EXT4", "MKFS_EXT4", "BLKID", "FSCK", "MOUNT", "UMOUNT"} +var toolsEnvs = []string{"CP", "DD", "RM", "FSCK_EXT4", "MKFS_EXT4", "BLKID", "FSCK", "MOUNT", "UMOUNT"} // addToolsToPATH is used to update the PATH to contain necessary tool binaries for // coreutils, util-linux and ext4. diff --git a/disk-mapper/internal/test/integration_test.go b/disk-mapper/internal/test/integration_test.go index cc865c256b..5f0478839f 100644 --- a/disk-mapper/internal/test/integration_test.go +++ b/disk-mapper/internal/test/integration_test.go @@ -37,7 +37,7 @@ const ( var diskPath = flag.String("disk", "", "Path to the disk to use for the benchmark") -var toolsEnvs []string = []string{"DD", "RM"} +var toolsEnvs = []string{"DD", "RM"} // addToolsToPATH is used to update the PATH to contain necessary tool binaries for // coreutils.
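
Reviewer note (not part of the patch): both test files above add a MarkDiskForReset method to their disk stubs, and run.go now threads a single *diskencryption.DiskEncryption through initserver.New and joinclient.New. This implies the unexported disk interfaces in those packages gained MarkDiskForReset alongside the methods the stubs already implement. The exact interface definitions are not shown in this diff, so the sketch below is an assumption reconstructed from the stub methods and from how run.go and getDiskUUID use the disk (Open returning a free func, UUID, UpdatePassphrase); it is meant only to make the shared contract explicit.

// Sketch of the disk contract the stubs satisfy — assumed, not taken from this diff.
type encryptedDisk interface {
	// Open prepares the underlying disk device and returns a function to release it.
	Open() (free func(), err error)
	// UUID returns the disk's UUID.
	UUID() (string, error)
	// UpdatePassphrase replaces the initial passphrase with the given one.
	UpdatePassphrase(passphrase string) error
	// MarkDiskForReset marks the state disk as not initialized so it is wiped on reboot.
	MarkDiskForReset() error
}

With this shape, the concrete *diskencryption.DiskEncryption from the first hunk satisfies the interface in both packages, which is what lets run.go construct the disk once and inject it into the init server and the join client instead of each package calling diskencryption.New itself.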