Skip to content

Commit

Permalink
feat(windows-agent): Prevent multiple distros from starting at the sa…
Browse files Browse the repository at this point in the history
…me time (#421)

Prevents multiple distros from starting at the same time, so that the
machine does not freeze up, for example, when a pro token is introduced.

Also relevant during startup, where all the failed tasks are retried.

---

UDENG-1716
  • Loading branch information
EduardGomezEscandell authored Dec 11, 2023
2 parents 97e52e0 + b4a834e commit a7e5bdd
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 21 deletions.
10 changes: 7 additions & 3 deletions windows-agent/internal/distros/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ type DistroDB struct {
ctx context.Context
cancelCtx func()
once sync.Once

// Multiple distros starting at the same time can cause WSL (and the whole machine) to freeze up.
// This mutex is used to block multiple distros from starting at the same time.
distroStartMu sync.Mutex
}

// New creates a database and populates it with data in the file located
Expand Down Expand Up @@ -147,7 +151,7 @@ func (db *DistroDB) GetDistroAndUpdateProperties(ctx context.Context, name strin
if !found {
log.Debugf(ctx, "Cache miss, creating %q and adding it to the database", name)

d, err := distro.New(db.ctx, name, props, db.storageDir, distro.WithProvisioning(db.provisioning))
d, err := distro.New(db.ctx, name, props, db.storageDir, &db.distroStartMu, distro.WithProvisioning(db.provisioning))
if err != nil {
return nil, err
}
Expand All @@ -165,7 +169,7 @@ func (db *DistroDB) GetDistroAndUpdateProperties(ctx context.Context, name strin
go d.Cleanup(ctx)
delete(db.distros, normalizedName)

d, err := distro.New(db.ctx, name, props, db.storageDir, distro.WithProvisioning(db.provisioning))
d, err := distro.New(db.ctx, name, props, db.storageDir, &db.distroStartMu, distro.WithProvisioning(db.provisioning))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -252,7 +256,7 @@ func (db *DistroDB) load(ctx context.Context) error {
// Initializing distros into database
db.distros = make(map[string]*distro.Distro, len(distros))
for _, inert := range distros {
d, err := inert.newDistro(ctx, db.storageDir)
d, err := inert.newDistro(ctx, db.storageDir, &db.distroStartMu)
if err != nil {
log.Warningf(ctx, "Read invalid distro from database: %#+v", inert)
continue
Expand Down
5 changes: 3 additions & 2 deletions windows-agent/internal/distros/database/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package database

import (
"context"
"sync"

"github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/distros/distro"
)

type SerializableDistro = serializableDistro

// NewDistro is a wrapper around newDistro so as to make it accessible to tests.
func (in SerializableDistro) NewDistro(ctx context.Context, storageDir string) (*distro.Distro, error) {
return in.newDistro(ctx, storageDir)
func (in SerializableDistro) NewDistro(ctx context.Context, storageDir string, startupMu *sync.Mutex) (*distro.Distro, error) {
return in.newDistro(ctx, storageDir, startupMu)
}

// NewSerializableDistro is a wrapper around newSerializableDistro so as to make it accessible to tests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package database

import (
"context"
"sync"

"github.com/canonical/ubuntu-pro-for-windows/windows-agent/internal/distros/distro"
"github.com/google/uuid"
Expand All @@ -20,12 +21,12 @@ type serializableDistro struct {

// newDistro calls distro.New with the name, GUID and properties specified
// in its inert counterpart.
func (in serializableDistro) newDistro(ctx context.Context, storageDir string) (*distro.Distro, error) {
func (in serializableDistro) newDistro(ctx context.Context, storageDir string, startupMu *sync.Mutex) (*distro.Distro, error) {
GUID, err := uuid.Parse(in.GUID)
if err != nil {
return nil, err
}
return distro.New(ctx, in.Name, in.Properties, storageDir, distro.WithGUID(GUID))
return distro.New(ctx, in.Name, in.Properties, storageDir, startupMu, distro.WithGUID(GUID))
}

// newSerializableDistro takes the information in distro.Distro relevant to the database
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package database_test

import (
"context"
"sync"
"testing"

"github.com/canonical/ubuntu-pro-for-windows/common/wsltestutils"
Expand Down Expand Up @@ -109,7 +110,10 @@ func TestSerializableDistroNewDistro(t *testing.T) {
GUID: tc.guid,
}

d, err := s.NewDistro(ctx, t.TempDir())
// This distro is never started, so no need for any global mutex
var mu sync.Mutex

d, err := s.NewDistro(ctx, t.TempDir(), &mu)
if err == nil {
defer d.Cleanup(context.Background())
}
Expand Down Expand Up @@ -140,7 +144,10 @@ func TestNewSerializableDistro(t *testing.T) {
Hostname: "NegativeMachine",
}

d, err := distro.New(ctx, registeredDistro, props, t.TempDir())
// This distro is never started, so no need for any global mutex
var mu sync.Mutex

d, err := distro.New(ctx, registeredDistro, props, t.TempDir(), &mu)
require.NoError(t, err, "Setup: distro New() should return no error")

s := database.NewSerializableDistro(d)
Expand Down
8 changes: 7 additions & 1 deletion windows-agent/internal/distros/distro/distro.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package distro

import (
"context"
"errors"
"fmt"
"os"
"sync"
Expand Down Expand Up @@ -91,7 +92,7 @@ func WithProvisioning(c worker.Provisioning) Option {
//
// - To avoid the latter check, you can pass a default-constructed identity.GUID. In that
// case, the distro will be created with its currently registered GUID.
func New(ctx context.Context, name string, props Properties, storageDir string, args ...Option) (distro *Distro, err error) {
func New(ctx context.Context, name string, props Properties, storageDir string, startupMu *sync.Mutex, args ...Option) (distro *Distro, err error) {
decorate.OnError(&err, "could not initialize distro %q", name)

var nilGUID uuid.UUID
Expand Down Expand Up @@ -129,11 +130,16 @@ func New(ctx context.Context, name string, props Properties, storageDir string,
}
}

if startupMu == nil {
return nil, errors.New("startup mutex must not be nil")
}

distro = &Distro{
identity: id,
properties: props,
stateManager: &stateManager{
distroIdentity: id,
startupMu: startupMu,
},
}

Expand Down
14 changes: 12 additions & 2 deletions windows-agent/internal/distros/distro/distro_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@ import (
// The distro is guaranteed to be running so long as the counter is above 0. This counter can
// be increased or decreased on demand, and is thread-safe.
type stateManager struct {
distroIdentity identity

refcount uint32
cancel func()
mu sync.Mutex

distroIdentity identity
// mu is a mutex for the refcount and the cancel func. We cannot use an atomic because increasing
// or decreasing the count entails more operations than simply adding one to this number.
mu sync.Mutex

// startupMu protects against multiple distros starting at the same time. This could cause WSL
// (and the whole machine) to freeze up.
startupMu *sync.Mutex
}

// state returns the state of the WSL distro, as implemeted by GoWSL.
Expand Down Expand Up @@ -111,6 +118,9 @@ func (m *stateManager) reset() {
//
// The distro will be running by the time keepAwake returns.
func (m *stateManager) keepAwake(ctx context.Context) (err error) {
m.startupMu.Lock()
defer m.startupMu.Unlock()

// Wake up distro
if err := touchdistro.Touch(ctx, m.distroIdentity.Name); err != nil {
return fmt.Errorf("could not wake distro up: %v", err)
Expand Down
98 changes: 89 additions & 9 deletions windows-agent/internal/distros/distro/distro_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"os"
"path/filepath"
"sync"
"testing"
"time"

Expand All @@ -28,6 +29,25 @@ func TestMain(m *testing.M) {
defer os.Exit(exit)
}

// globalStartupMu protects against multiple distros starting at the same time.
var globalStartupMu sync.Mutex

// startupMutex exists so that all distro tests share the same startup mutex.
// This mutex prevents multiple distros from starting at the same time, which
// could freeze the machine.
//
// When a mock WSL is used, this concern does not exist so we provide a new
// mutex for every test so they can run in parallel without interference.
func startupMutex() *sync.Mutex {
if wsl.MockAvailable() {
// No real distros: use a different mutex every test
return &sync.Mutex{}
}

// Real distros: use a the same mutex for all tests
return &globalStartupMu
}

func TestNew(t *testing.T) {
ctx := context.Background()
if wsl.MockAvailable() {
Expand All @@ -52,6 +72,7 @@ func TestNew(t *testing.T) {
withGUID string
preventWorkDirCreation bool
withProvisioning bool
nilMutex bool

wantErr bool
wantErrType error
Expand All @@ -67,6 +88,7 @@ func TestNew(t *testing.T) {
"Error when the distro is not registered": {distro: nonRegisteredDistro, wantErr: true, wantErrType: &distro.NotValidError{}},
"Error when the distro is not registered, but the GUID is": {distro: nonRegisteredDistro, withGUID: registeredGUID, wantErr: true, wantErrType: &distro.NotValidError{}},
"Error when neither the distro nor the GUID are registered": {distro: nonRegisteredDistro, withGUID: fakeGUID, wantErr: true, wantErrType: &distro.NotValidError{}},
"Error when the startup mutex is nil": {distro: registeredDistro, nilMutex: true, wantErr: true},
}

for name, tc := range testCases {
Expand All @@ -93,7 +115,12 @@ func TestNew(t *testing.T) {
require.NoError(t, err, "Setup: could not write file to interfere with distro's MkDir")
}

d, err = distro.New(ctx, tc.distro, props, workDir, args...)
mu := startupMutex()
if tc.nilMutex {
mu = nil
}

d, err = distro.New(ctx, tc.distro, props, workDir, mu, args...)
defer d.Cleanup(context.Background())

if tc.wantErr {
Expand Down Expand Up @@ -123,7 +150,8 @@ func TestString(t *testing.T) {

GUID, err := uuid.Parse(guid)
require.NoError(t, err, "Setup: could not parse guid %s: %v", GUID, err)
d, err := distro.New(ctx, name, distro.Properties{}, t.TempDir(), distro.WithGUID(GUID))

d, err := distro.New(ctx, name, distro.Properties{}, t.TempDir(), startupMutex(), distro.WithGUID(GUID))
defer d.Cleanup(context.Background())

require.NoError(t, err, "Setup: unexpected error in distro.New")
Expand Down Expand Up @@ -163,7 +191,7 @@ func TestIsValid(t *testing.T) {
tc := tc
t.Run(name, func(t *testing.T) {
// Create an always valid distro
d, err := distro.New(ctx, distro1, distro.Properties{}, t.TempDir())
d, err := distro.New(ctx, distro1, distro.Properties{}, t.TempDir(), startupMutex())
defer d.Cleanup(context.Background())

require.NoError(t, err, "Setup: distro New() should return no errors")
Expand Down Expand Up @@ -219,7 +247,7 @@ func TestSetProperties(t *testing.T) {
}

dname, _ := wsltestutils.RegisterDistro(t, ctx, false)
d, err := distro.New(ctx, dname, props1, t.TempDir())
d, err := distro.New(ctx, dname, props1, t.TempDir(), startupMutex())
require.NoError(t, err, "Setup: distro New should return no errors")

p := props2
Expand Down Expand Up @@ -295,7 +323,7 @@ func TestLockReleaseAwake(t *testing.T) {

distroName, _ := wsltestutils.RegisterDistro(t, ctx, true)

d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir())
d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir(), startupMutex())
defer d.Cleanup(context.Background())

require.NoError(t, err, "Setup: distro New should return no error")
Expand Down Expand Up @@ -404,6 +432,57 @@ func TestLockReleaseAwake(t *testing.T) {
}
}

func TestNoSimultaneousStartups(t *testing.T) {
t.Parallel()

if !wsl.MockAvailable() {
t.Skip("Skipped without mocks to avoid messing with the global mutex")
}

ctx := wsl.WithMock(context.Background(), wslmock.New())
var startupMu sync.Mutex

distroName, _ := wsltestutils.RegisterDistro(t, ctx, true)
d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir(), &startupMu)
defer d.Cleanup(context.Background())
require.NoError(t, err, "Setup: distro New should return no error")

wsltestutils.TerminateDistro(t, ctx, distroName)

// Lock the startup mutex to pretend some other distro is starting up
const lockAwakeMaxTime = 20 * time.Second
ch := make(chan error)

func() {
startupMu.Lock()
defer startupMu.Unlock()

go func() {
// We send the error to be asserted in the main goroutine because
// failed assertions outside the test goroutine cause panics.
ch <- d.LockAwake()
close(ch)
}()

time.Sleep(lockAwakeMaxTime)
state := wsltestutils.DistroState(t, ctx, distroName)
require.Equal(t, "Stopped", state, "Distro should not start while the mutex is locked")
}()

// The startup mutex has been released to pretend some other distro finished starting up

select {
case <-time.After(lockAwakeMaxTime):
require.Fail(t, "LockAwake should have returned after releasing the startup mutex")
case err := <-ch:
require.NoError(t, err, "LockAwake should return no error")
break
}

state := wsltestutils.DistroState(t, ctx, distroName)
require.Equal(t, "Running", state, "Distro should start after the mutex is released")
}

func TestState(t *testing.T) {
if wsl.MockAvailable() {
t.Parallel()
Expand Down Expand Up @@ -437,7 +516,7 @@ func TestState(t *testing.T) {
}

distroName, _ := wsltestutils.RegisterDistro(t, ctx, true)
d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir())
d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir(), startupMutex())
require.NoError(t, err, "Setup: distro New should return no errors")

gowslDistro := wsl.NewDistro(ctx, distroName)
Expand Down Expand Up @@ -502,6 +581,7 @@ func TestWorkerConstruction(t *testing.T) {
distroName,
distro.Properties{},
workDir,
startupMutex(),
distro.WithTaskProcessingContext(ctx),
distro.WithProvisioning(provisioning),
withMockWorker)
Expand Down Expand Up @@ -533,7 +613,7 @@ func TestInvalidateIdempotent(t *testing.T) {

inj, w := mockWorkerInjector(false)

d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir(), inj)
d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir(), &globalStartupMu, inj)
defer d.Cleanup(context.Background())
require.NoError(t, err, "Setup: distro New should return no error")

Expand Down Expand Up @@ -598,7 +678,7 @@ func TestWorkerWrappers(t *testing.T) {

inj, w := mockWorkerInjector(false)

d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir(), inj)
d, err := distro.New(ctx, distroName, distro.Properties{}, t.TempDir(), startupMutex(), inj)
defer d.Cleanup(context.Background())
require.NoError(t, err, "Setup: distro New should return no error")

Expand Down Expand Up @@ -690,7 +770,7 @@ func TestUninstall(t *testing.T) {

name, _ := wsltestutils.RegisterDistro(t, ctx, false)

d, err := distro.New(ctx, name, distro.Properties{}, t.TempDir())
d, err := distro.New(ctx, name, distro.Properties{}, t.TempDir(), startupMutex())
require.NoError(t, err, "Setup: distro New should return no errors")

if tc.unregisterDistro {
Expand Down

0 comments on commit a7e5bdd

Please sign in to comment.