Skip to content

Commit

Permalink
Bugfix: avoid partial config read/writes.
Browse files Browse the repository at this point in the history
The registry is written on AFTER the file. This means that failing to
write to the registry will cause the file to be updated but not the
registry. This is leaves the config in a bad state.

To solve it, we create a backup of the file, and restore it when the
registry write fails.

The same happens when reading.
  • Loading branch information
EduardGomezEscandell committed Oct 27, 2023
1 parent b3ac855 commit 1fb9d48
Showing 1 changed file with 40 additions and 5 deletions.
45 changes: 40 additions & 5 deletions windows-agent/internal/config/config_marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,18 @@ type marshalHelper struct {
func (c *Config) load() (err error) {
defer decorate.OnError(&err, "could not load data for Config")

// Backup the config in case the registry fails.
// This avoids partial loads.
l := c.landscape
s := c.subscription

if err := c.loadFile(); err != nil {
return fmt.Errorf("could not load config from the chache file: %v", err)
}

if err := c.loadRegistry(); err != nil {
c.landscape = l
c.subscription = s
return fmt.Errorf("could not load config from the registry: %v", err)
}

Expand Down Expand Up @@ -97,16 +104,44 @@ func readFromRegistry(r Registry, key uintptr, field string) (string, error) {
func (c *Config) dump() (err error) {
defer decorate.OnError(&err, "could not store Config data")

var errs error
if err := c.dumpRegistry(); err != nil {
errs = errors.Join(errs, err)
// Backup the file in case the registry write fails.
// This avoids partial writes
restore, err := makeBackup(c.cachePath)
if err != nil {
return err
}

if err := c.dumpFile(); err != nil {
errs = errors.Join(errs, err)
return err
}

if err := c.dumpRegistry(); err != nil {
return errors.Join(err, restore())
}

return nil
}

func makeBackup(originalPath string) (func() error, error) {
backupPath := originalPath + ".backup"

// filterFsError is a helper to avoid having to do the errors.Is dance every time.
filterFsError := func(err error) error {
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return err
}

err := os.Rename(originalPath, backupPath)
if filterFsError(err) != nil {
return nil, fmt.Errorf("could not create backup: %v", err)
}

return errs
return func() error {
err := os.Rename(backupPath, originalPath)
return filterFsError(err)
}, nil
}

func (c *Config) dumpRegistry() error {
Expand Down

0 comments on commit 1fb9d48

Please sign in to comment.