diff --git a/windows-agent/internal/distros/database/database.go b/windows-agent/internal/distros/database/database.go index f226ed752..34d5e1b91 100644 --- a/windows-agent/internal/distros/database/database.go +++ b/windows-agent/internal/distros/database/database.go @@ -40,6 +40,10 @@ type DistroDB struct { ctx context.Context cancelCtx func() once sync.Once + + // Multiple distros starting at the same time can cause WSL (and the whole machine) to freeze up. + // This mutex is used to block multiple distros from starting at the same time. + distroStartMu sync.Mutex } // New creates a database and populates it with data in the file located @@ -147,7 +151,7 @@ func (db *DistroDB) GetDistroAndUpdateProperties(ctx context.Context, name strin if !found { log.Debugf(ctx, "Cache miss, creating %q and adding it to the database", name) - d, err := distro.New(db.ctx, name, props, db.storageDir, distro.WithProvisioning(db.provisioning)) + d, err := distro.New(db.ctx, name, props, db.storageDir, &db.distroStartMu, distro.WithProvisioning(db.provisioning)) if err != nil { return nil, err } @@ -165,7 +169,7 @@ func (db *DistroDB) GetDistroAndUpdateProperties(ctx context.Context, name strin go d.Cleanup(ctx) delete(db.distros, normalizedName) - d, err := distro.New(db.ctx, name, props, db.storageDir, distro.WithProvisioning(db.provisioning)) + d, err := distro.New(db.ctx, name, props, db.storageDir, &db.distroStartMu, distro.WithProvisioning(db.provisioning)) if err != nil { return nil, err } @@ -252,7 +256,7 @@ func (db *DistroDB) load(ctx context.Context) error { // Initializing distros into database db.distros = make(map[string]*distro.Distro, len(distros)) for _, inert := range distros { - d, err := inert.newDistro(ctx, db.storageDir) + d, err := inert.newDistro(ctx, db.storageDir, &db.distroStartMu) if err != nil { log.Warningf(ctx, "Read invalid distro from database: %#+v", inert) continue diff --git a/windows-agent/internal/distros/database/export_test.go b/windows-agent/internal/distros/database/export_test.go index 8adb6206e..2c159ed90 100644 --- a/windows-agent/internal/distros/database/export_test.go +++ b/windows-agent/internal/distros/database/export_test.go @@ -2,6 +2,7 @@ package database import ( "context" + "sync" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/distros/distro" ) @@ -9,8 +10,8 @@ import ( type SerializableDistro = serializableDistro // NewDistro is a wrapper around newDistro so as to make it accessible to tests. -func (in SerializableDistro) NewDistro(ctx context.Context, storageDir string) (*distro.Distro, error) { - return in.newDistro(ctx, storageDir) +func (in SerializableDistro) NewDistro(ctx context.Context, storageDir string, startupMu *sync.Mutex) (*distro.Distro, error) { + return in.newDistro(ctx, storageDir, startupMu) } // NewSerializableDistro is a wrapper around newSerializableDistro so as to make it accessible to tests. diff --git a/windows-agent/internal/distros/database/serializable_distro.go b/windows-agent/internal/distros/database/serializable_distro.go index daa893f97..44e4a9a07 100644 --- a/windows-agent/internal/distros/database/serializable_distro.go +++ b/windows-agent/internal/distros/database/serializable_distro.go @@ -2,6 +2,7 @@ package database import ( "context" + "sync" "github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/distros/distro" "github.com/google/uuid" @@ -20,12 +21,12 @@ type serializableDistro struct { // newDistro calls distro.New with the name, GUID and properties specified // in its inert counterpart. -func (in serializableDistro) newDistro(ctx context.Context, storageDir string) (*distro.Distro, error) { +func (in serializableDistro) newDistro(ctx context.Context, storageDir string, startupMu *sync.Mutex) (*distro.Distro, error) { GUID, err := uuid.Parse(in.GUID) if err != nil { return nil, err } - return distro.New(ctx, in.Name, in.Properties, storageDir, distro.WithGUID(GUID)) + return distro.New(ctx, in.Name, in.Properties, storageDir, startupMu, distro.WithGUID(GUID)) } // newSerializableDistro takes the information in distro.Distro relevant to the database diff --git a/windows-agent/internal/distros/database/serializable_distro_test.go b/windows-agent/internal/distros/database/serializable_distro_test.go index 71c14d202..03440845b 100644 --- a/windows-agent/internal/distros/database/serializable_distro_test.go +++ b/windows-agent/internal/distros/database/serializable_distro_test.go @@ -2,6 +2,7 @@ package database_test import ( "context" + "sync" "testing" "github.com/canonical/ubuntu-pro-for-windows/common/wsltestutils" @@ -109,7 +110,10 @@ func TestSerializableDistroNewDistro(t *testing.T) { GUID: tc.guid, } - d, err := s.NewDistro(ctx, t.TempDir()) + // This distro is never started, so no need for any global mutex + var mu sync.Mutex + + d, err := s.NewDistro(ctx, t.TempDir(), &mu) if err == nil { defer d.Cleanup(context.Background()) } @@ -140,7 +144,10 @@ func TestNewSerializableDistro(t *testing.T) { Hostname: "NegativeMachine", } - d, err := distro.New(ctx, registeredDistro, props, t.TempDir()) + // This distro is never started, so no need for any global mutex + var mu sync.Mutex + + d, err := distro.New(ctx, registeredDistro, props, t.TempDir(), &mu) require.NoError(t, err, "Setup: distro New() should return no error") s := database.NewSerializableDistro(d) diff --git a/windows-agent/internal/distros/distro/distro.go b/windows-agent/internal/distros/distro/distro.go index 177da3ba8..73a95ca03 100644 --- a/windows-agent/internal/distros/distro/distro.go +++ b/windows-agent/internal/distros/distro/distro.go @@ -4,6 +4,7 @@ package distro import ( "context" + "errors" "fmt" "os" "sync" @@ -91,7 +92,7 @@ func WithProvisioning(c worker.Provisioning) Option { // // - To avoid the latter check, you can pass a default-constructed identity.GUID. In that // case, the distro will be created with its currently registered GUID. -func New(ctx context.Context, name string, props Properties, storageDir string, args ...Option) (distro *Distro, err error) { +func New(ctx context.Context, name string, props Properties, storageDir string, startupMu *sync.Mutex, args ...Option) (distro *Distro, err error) { decorate.OnError(&err, "could not initialize distro %q", name) var nilGUID uuid.UUID @@ -129,11 +130,16 @@ func New(ctx context.Context, name string, props Properties, storageDir string, } } + if startupMu == nil { + return nil, errors.New("startup mutex must not be nil") + } + distro = &Distro{ identity: id, properties: props, stateManager: &stateManager{ distroIdentity: id, + startupMu: startupMu, }, } diff --git a/windows-agent/internal/distros/distro/distro_state.go b/windows-agent/internal/distros/distro/distro_state.go index 83d782ca0..84660a30f 100644 --- a/windows-agent/internal/distros/distro/distro_state.go +++ b/windows-agent/internal/distros/distro/distro_state.go @@ -16,11 +16,18 @@ import ( // The distro is guaranteed to be running so long as the counter is above 0. This counter can // be increased or decreased on demand, and is thread-safe. type stateManager struct { + distroIdentity identity + refcount uint32 cancel func() - mu sync.Mutex - distroIdentity identity + // mu is a mutex for the refcount and the cancel func. We cannot use an atomic because increasing + // or decreasing the count entails more operations than simply adding one to this number. + mu sync.Mutex + + // startupMu protects against multiple distros starting at the same time. This could cause WSL + // (and the whole machine) to freeze up. + startupMu *sync.Mutex } // state returns the state of the WSL distro, as implemeted by GoWSL. @@ -111,6 +118,9 @@ func (m *stateManager) reset() { // // The distro will be running by the time keepAwake returns. func (m *stateManager) keepAwake(ctx context.Context) (err error) { + m.startupMu.Lock() + defer m.startupMu.Unlock() + // Wake up distro if err := touchdistro.Touch(ctx, m.distroIdentity.Name); err != nil { return fmt.Errorf("could not wake distro up: %v", err) diff --git a/windows-agent/internal/distros/distro/distro_test.go b/windows-agent/internal/distros/distro/distro_test.go index 7fd6db76a..20895016b 100644 --- a/windows-agent/internal/distros/distro/distro_test.go +++ b/windows-agent/internal/distros/distro/distro_test.go @@ -5,6 +5,7 @@ import ( "errors" "os" "path/filepath" + "sync" "testing" "time" @@ -28,6 +29,25 @@ func TestMain(m *testing.M) { defer os.Exit(exit) } +// globalStartupMu protects against multiple distros starting at the same time. +var globalStartupMu sync.Mutex + +// startupMutex exists so that all distro tests share the same startup mutex. +// This mutex prevents multiple distros from starting at the same time, which +// could freeze the machine. +// +// When a mock WSL is used, this concern does not exist so we provide a new +// mutex for every test so they can run in parallel without interference. +func startupMutex() *sync.Mutex { + if wsl.MockAvailable() { + // No real distros: use a different mutex every test + return &sync.Mutex{} + } + + // Real distros: use a the same mutex for all tests + return &globalStartupMu +} + func TestNew(t *testing.T) { ctx := context.Background() if wsl.MockAvailable() { @@ -52,6 +72,7 @@ func TestNew(t *testing.T) { withGUID string preventWorkDirCreation bool withProvisioning bool + nilMutex bool wantErr bool wantErrType error @@ -67,6 +88,7 @@ func TestNew(t *testing.T) { "Error when the distro is not registered": {distro: nonRegisteredDistro, wantErr: true, wantErrType: &distro.NotValidError{}}, "Error when the distro is not registered, but the GUID is": {distro: nonRegisteredDistro, withGUID: registeredGUID, wantErr: true, wantErrType: &distro.NotValidError{}}, "Error when neither the distro nor the GUID are registered": {distro: nonRegisteredDistro, withGUID: fakeGUID, wantErr: true, wantErrType: &distro.NotValidError{}}, + "Error when the startup mutex is nil": {distro: registeredDistro, nilMutex: true, wantErr: true}, } for name, tc := range testCases { @@ -93,7 +115,12 @@ func TestNew(t *testing.T) { require.NoError(t, err, "Setup: could not write file to interfere with distro's MkDir") } - d, err = distro.New(ctx, tc.distro, props, workDir, args...) + mu := startupMutex() + if tc.nilMutex { + mu = nil + } + + d, err = distro.New(ctx, tc.distro, props, workDir, mu, args...) defer d.Cleanup(context.Background()) if tc.wantErr { @@ -123,7 +150,8 @@ func TestString(t *testing.T) { GUID, err := uuid.Parse(guid) require.NoError(t, err, "Setup: could not parse guid %s: %v", GUID, err) - d, err := distro.New(ctx, name, distro.Properties{}, t.TempDir(), distro.WithGUID(GUID)) + + d, err := distro.New(ctx, name, distro.Properties{}, t.TempDir(), startupMutex(), distro.WithGUID(GUID)) defer d.Cleanup(context.Background()) require.NoError(t, err, "Setup: unexpected error in distro.New") @@ -163,7 +191,7 @@ func TestIsValid(t *testing.T) { tc := tc t.Run(name, func(t *testing.T) { // Create an always valid distro - d, err := distro.New(ctx, distro1, distro.Properties{}, t.TempDir()) + d, err := distro.New(ctx, distro1, distro.Properties{}, t.TempDir(), startupMutex()) defer d.Cleanup(context.Background()) require.NoError(t, err, "Setup: distro New() should return no errors") @@ -219,7 +247,7 @@ func TestSetProperties(t *testing.T) { } dname, _ := wsltestutils.RegisterDistro(t, ctx, false) - d, err := distro.New(ctx, dname, props1, t.TempDir()) + d, err := distro.New(ctx, dname, props1, t.TempDir(), startupMutex()) require.NoError(t, err, "Setup: distro New should return no errors") p := props2 @@ -295,7 +323,7 @@ func TestLockReleaseAwake(t *testing.T) { distroName, _ := wsltestutils.RegisterDistro(t, ctx, true) - d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir()) + d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir(), startupMutex()) defer d.Cleanup(context.Background()) require.NoError(t, err, "Setup: distro New should return no error") @@ -404,6 +432,57 @@ func TestLockReleaseAwake(t *testing.T) { } } +func TestNoSimultaneousStartups(t *testing.T) { + t.Parallel() + + if !wsl.MockAvailable() { + t.Skip("Skipped without mocks to avoid messing with the global mutex") + } + + ctx := wsl.WithMock(context.Background(), wslmock.New()) + var startupMu sync.Mutex + + distroName, _ := wsltestutils.RegisterDistro(t, ctx, true) + d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir(), &startupMu) + defer d.Cleanup(context.Background()) + require.NoError(t, err, "Setup: distro New should return no error") + + wsltestutils.TerminateDistro(t, ctx, distroName) + + // Lock the startup mutex to pretend some other distro is starting up + const lockAwakeMaxTime = 20 * time.Second + ch := make(chan error) + + func() { + startupMu.Lock() + defer startupMu.Unlock() + + go func() { + // We send the error to be asserted in the main goroutine because + // failed assertions outside the test goroutine cause panics. + ch <- d.LockAwake() + close(ch) + }() + + time.Sleep(lockAwakeMaxTime) + state := wsltestutils.DistroState(t, ctx, distroName) + require.Equal(t, "Stopped", state, "Distro should not start while the mutex is locked") + }() + + // The startup mutex has been released to pretend some other distro finished starting up + + select { + case <-time.After(lockAwakeMaxTime): + require.Fail(t, "LockAwake should have returned after releasing the startup mutex") + case err := <-ch: + require.NoError(t, err, "LockAwake should return no error") + break + } + + state := wsltestutils.DistroState(t, ctx, distroName) + require.Equal(t, "Running", state, "Distro should start after the mutex is released") +} + func TestState(t *testing.T) { if wsl.MockAvailable() { t.Parallel() @@ -437,7 +516,7 @@ func TestState(t *testing.T) { } distroName, _ := wsltestutils.RegisterDistro(t, ctx, true) - d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir()) + d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir(), startupMutex()) require.NoError(t, err, "Setup: distro New should return no errors") gowslDistro := wsl.NewDistro(ctx, distroName) @@ -502,6 +581,7 @@ func TestWorkerConstruction(t *testing.T) { distroName, distro.Properties{}, workDir, + startupMutex(), distro.WithTaskProcessingContext(ctx), distro.WithProvisioning(provisioning), withMockWorker) @@ -533,7 +613,7 @@ func TestInvalidateIdempotent(t *testing.T) { inj, w := mockWorkerInjector(false) - d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir(), inj) + d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir(), &globalStartupMu, inj) defer d.Cleanup(context.Background()) require.NoError(t, err, "Setup: distro New should return no error") @@ -598,7 +678,7 @@ func TestWorkerWrappers(t *testing.T) { inj, w := mockWorkerInjector(false) - d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir(), inj) + d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir(), startupMutex(), inj) defer d.Cleanup(context.Background()) require.NoError(t, err, "Setup: distro New should return no error") @@ -690,7 +770,7 @@ func TestUninstall(t *testing.T) { name, _ := wsltestutils.RegisterDistro(t, ctx, false) - d, err := distro.New(ctx, name, distro.Properties{}, t.TempDir()) + d, err := distro.New(ctx, name, distro.Properties{}, t.TempDir(), startupMutex()) require.NoError(t, err, "Setup: distro New should return no errors") if tc.unregisterDistro {