diff --git a/windows-agent/cmd/ubuntu-pro-agent/agent/agent.go b/windows-agent/cmd/ubuntu-pro-agent/agent/agent.go index 5feb0a158..7d7bd5809 100644 --- a/windows-agent/cmd/ubuntu-pro-agent/agent/agent.go +++ b/windows-agent/cmd/ubuntu-pro-agent/agent/agent.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "io" "os" "path/filepath" "runtime" @@ -80,12 +81,19 @@ func New(o ...option) *App { } setVerboseMode(a.config.Verbosity) - log.Debug(context.Background(), "Debug mode is enabled") return nil }, RunE: func(cmd *cobra.Command, args []string) error { - return a.serve(o...) + ctx := context.Background() + + cleanup, err := a.setUpLogger(ctx) + if err != nil { + log.Warningf(ctx, "could not set logger output: %v", err) + } + defer cleanup() + + return a.serve(ctx, o...) }, // We display usage error ourselves SilenceErrors: true, @@ -96,14 +104,13 @@ func New(o ...option) *App { // subcommands a.installVersion() + a.installClean() return &a } // serve creates new GRPC services and listen on a TCP socket. This call is blocking until we quit it. -func (a *App) serve(args ...option) error { - ctx := context.TODO() - +func (a *App) serve(ctx context.Context, args ...option) error { var opt options for _, f := range args { f(&opt) @@ -222,7 +229,7 @@ func (a *App) publicDir(opts options) (string, error) { opts.publicDir = filepath.Join(homeDir, common.UserProfileDir) } - if err := os.MkdirAll(opts.publicDir, 0600); err != nil { + if err := os.MkdirAll(opts.publicDir, 0700); err != nil { return "", fmt.Errorf("could not create public dir %s: %v", opts.publicDir, err) } @@ -240,9 +247,46 @@ func (a *App) privateDir(opts options) (string, error) { opts.privateDir = filepath.Join(localAppData, common.LocalAppDataDir) } - if err := os.MkdirAll(opts.privateDir, 0600); err != nil { + if err := os.MkdirAll(opts.privateDir, 0700); err != nil { return "", fmt.Errorf("could not create private dir %s: %v", opts.privateDir, err) } return opts.privateDir, nil } + +func (a *App) setUpLogger(ctx context.Context) (func(), error) { + noop := func() {} + + logrus.SetFormatter(&logrus.TextFormatter{ + DisableQuote: true, + }) + + publicDir, err := a.PublicDir() + if err != nil { + return noop, err + } + + logFile := filepath.Join(publicDir, "log") + + // Move old log file + oldLogFile := filepath.Join(publicDir, "log.old") + err = os.Rename(logFile, oldLogFile) + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Warningf(ctx, "Could not archive previous log file: %v", err) + } + + f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE, 0600) + if err != nil { + return noop, fmt.Errorf("could not open log file: %v", err) + } + + // Write both to file and to Stdout. The latter is useful for local development. + w := io.MultiWriter(f, os.Stdout) + logrus.SetOutput(w) + + fmt.Fprintf(f, "\n======= STARTUP =======\n") + log.Infof(ctx, "Version: %s", consts.Version) + log.Debug(ctx, "Debug mode is enabled") + + return func() { _ = f.Close() }, nil +} diff --git a/windows-agent/cmd/ubuntu-pro-agent/agent/agent_test.go b/windows-agent/cmd/ubuntu-pro-agent/agent/agent_test.go index 7519ad23f..5bc92d0f9 100644 --- a/windows-agent/cmd/ubuntu-pro-agent/agent/agent_test.go +++ b/windows-agent/cmd/ubuntu-pro-agent/agent/agent_test.go @@ -225,6 +225,174 @@ func TestPublicDir(t *testing.T) { } } +func TestLogs(t *testing.T) { + // Not parallel because we modify the environment + + fooContent := "foo" + emptyContent := "" + + tests := map[string]struct { + existingLogContent string + + runError bool + usageErrorReturn bool + logDirError bool + + wantOldLogFileContent *string + }{ + "Run and exit successfully despite logs not being written": {logDirError: true}, + "Existing log file has been renamed to old": {existingLogContent: "foo", wantOldLogFileContent: &fooContent}, + "Existing empty log file has been renamed to old": {existingLogContent: "-", wantOldLogFileContent: &emptyContent}, + "Ignore when failing to archive log file": {existingLogContent: "OLD_IS_DIRECTORY"}, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + // Not parallel because we modify the environment + + home := t.TempDir() + appData := filepath.Join(home, "AppData/Local") + + t.Setenv("UserProfile", home) + t.Setenv("LocalAppData", appData) + + a := agent.New(agent.WithRegistry(registry.NewMock())) + + var logFile, oldLogFile string + publicDir, err := a.PublicDir() + if err == nil { + logFile = filepath.Join(publicDir, "log") + oldLogFile = logFile + ".old" + switch tc.existingLogContent { + case "": + case "OLD_IS_DIRECTORY": + err := os.Mkdir(oldLogFile, 0700) + require.NoError(t, err, "Setup: create invalid log.old file") + err = os.WriteFile(logFile, []byte("Old log content"), 0600) + require.NoError(t, err, "Setup: creating pre-existing log file") + case "-": + tc.existingLogContent = "" + fallthrough + default: + err := os.WriteFile(logFile, []byte(tc.existingLogContent), 0600) + require.NoError(t, err, "Setup: creating pre-existing log file") + } + } + + ch := make(chan struct{}) + go func() { + _ = a.Run() // This always returns an error because the gRPC server is stopped + close(ch) + }() + + a.WaitReady() + + select { + case <-ch: + require.Fail(t, "Run should not exit") + default: + } + + a.Quit() + + select { + case <-time.After(20 * time.Second): + require.Fail(t, "Run should have exited") + default: + } + + // Don't check for log files if the directory was not writable + if logFile == "" { + return + } + if tc.wantOldLogFileContent != nil { + require.FileExists(t, oldLogFile, "Old log file should exist") + content, err := os.ReadFile(oldLogFile) + require.NoError(t, err, "Should be able to read old log file") + require.Equal(t, tc.existingLogContent, string(content), "Old log file content should be log's content") + } else { + require.NoFileExists(t, oldLogFile, "Old log file should not exist") + } + }) + } +} + +func TestClean(t *testing.T) { + // Not parallel because we modify the environment + + testCases := map[string]struct { + emptyUserProfile bool + emptyLocalAppDir bool + + wantErr bool + }{ + "Success": {}, + + "Error when %UserProfile% is empty": {emptyUserProfile: true, wantErr: true}, + "Error when %LocalAppData% is empty": {emptyLocalAppDir: true, wantErr: true}, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Not parallel because we modify the environment + + home := t.TempDir() + appData := filepath.Join(home, "AppData/Local") + + t.Setenv("LocalAppData", appData) + + if tc.emptyUserProfile { + t.Setenv("UserProfile", "") + } else { + t.Setenv("UserProfile", home) + + err := os.MkdirAll(filepath.Join(home, common.UserProfileDir), 0700) + require.NoError(t, err, "Setup: could not crate fake public directory") + + err = os.WriteFile(filepath.Join(home, common.UserProfileDir, "file"), []byte("test file"), 0600) + require.NoError(t, err, "Setup: could not write file inside the public directory") + + err = os.WriteFile(filepath.Join(home, ".unrelated"), []byte("test file"), 0600) + require.NoError(t, err, "Setup: could not write file outside the public directory") + } + + if tc.emptyLocalAppDir { + t.Setenv("LocalAppData", "") + } else { + t.Setenv("LocalAppData", appData) + + err := os.MkdirAll(filepath.Join(appData, common.LocalAppDataDir), 0700) + require.NoError(t, err, "Setup: could not crate fake private directory") + + err = os.WriteFile(filepath.Join(appData, common.LocalAppDataDir, "file"), []byte("test file"), 0600) + require.NoError(t, err, "Setup: could not write file inside the private directory") + + err = os.WriteFile(filepath.Join(appData, ".unrelated"), []byte("test file"), 0600) + require.NoError(t, err, "Setup: could not write file outside the private directory") + } + + a := agent.New(agent.WithRegistry(registry.NewMock())) + a.SetArgs("clean") + + err := a.Run() + if tc.wantErr { + require.Error(t, err, "Run should return an error") + } else { + require.NoError(t, err, "Run should not return an error") + } + + require.NoFileExists(t, filepath.Join(home, common.UserProfileDir), "Public directory should have been removed") + if !tc.emptyUserProfile { + require.FileExists(t, filepath.Join(home, ".unrelated"), "Unrelated file in home directory should still exist") + } + + require.NoFileExists(t, filepath.Join(appData, common.LocalAppDataDir), "Private directory should have been removed") + if !tc.emptyLocalAppDir { + require.FileExists(t, filepath.Join(appData, ".unrelated"), "Unrelated file in LocalAppData directory should still exist") + } + }) + } +} + // requireGoroutineStarted starts a goroutine and blocks until it has been launched. func requireGoroutineStarted(t *testing.T, f func()) { t.Helper() diff --git a/windows-agent/cmd/ubuntu-pro-agent/agent/clean.go b/windows-agent/cmd/ubuntu-pro-agent/agent/clean.go new file mode 100644 index 000000000..5d624ebe2 --- /dev/null +++ b/windows-agent/cmd/ubuntu-pro-agent/agent/clean.go @@ -0,0 +1,73 @@ +package agent + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "time" + + "github.com/canonical/ubuntu-pro-for-wsl/common" + "github.com/canonical/ubuntu-pro-for-wsl/common/i18n" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +func (a *App) installClean() { + cmd := &cobra.Command{ + Use: "clean", + Short: i18n.G("Removes all the agent's data and exits"), + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + defer log.Debug("clean command finished") + + // Stop the agent so that it doesn't interfere with file removal. + if err := stopAgent(); err != nil { + log.Warningf("could not stop agent: %v", err) + } + + // Clean up the agent's data. + return errors.Join( + cleanLocation("LocalAppData", common.LocalAppDataDir), + cleanLocation("UserProfile", common.UserProfileDir), + ) + }, + } + a.rootCmd.AddCommand(cmd) +} + +// stopAgent stops all other ubuntu-pro-agent instances (but not itself!). +func stopAgent() error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + filterPID := fmt.Sprintf("PID ne %d", os.Getpid()) + + //nolint:gosec // The return value of cmdName() is not user input. + out, err := exec.CommandContext(ctx, "taskkill.exe", + "/F", // Force-stop the process + "/IM", cmdName(), // Match the process name + "/FI", filterPID, // Filter out the current process. + ).CombinedOutput() + if err != nil { + return fmt.Errorf("could not stop process %s: %v. %s", cmdName(), err, out) + } + + return nil +} + +func cleanLocation(rootEnv, relpath string) error { + root := os.Getenv(rootEnv) + if root == "" { + return fmt.Errorf("could not clean up location: environment variable %q is not set", rootEnv) + } + + path := filepath.Join(root, relpath) + if err := os.RemoveAll(path); err != nil { + return fmt.Errorf("could not clean up location %s: %v", path, err) + } + + return nil +} diff --git a/windows-agent/cmd/ubuntu-pro-agent/main.go b/windows-agent/cmd/ubuntu-pro-agent/main.go index 3f739fa31..465376a97 100644 --- a/windows-agent/cmd/ubuntu-pro-agent/main.go +++ b/windows-agent/cmd/ubuntu-pro-agent/main.go @@ -3,19 +3,14 @@ package main import ( "context" - "errors" - "fmt" - "io" "os" "os/signal" - "path/filepath" "sync" "syscall" "github.com/canonical/ubuntu-pro-for-wsl/common" "github.com/canonical/ubuntu-pro-for-wsl/common/i18n" "github.com/canonical/ubuntu-pro-for-wsl/windows-agent/cmd/ubuntu-pro-agent/agent" - "github.com/canonical/ubuntu-pro-for-wsl/windows-agent/internal/consts" log "github.com/sirupsen/logrus" ) @@ -37,17 +32,6 @@ type app interface { func run(a app) int { defer installSignalHandler(a)() - log.SetFormatter(&log.TextFormatter{ - DisableQuote: true, - }) - - cleanup, err := setLoggerOutput(a) - if err != nil { - log.Warningf("could not set logger output: %v", err) - } else { - defer cleanup() - } - if err := a.Run(); err != nil { log.Error(context.Background(), err) @@ -60,36 +44,6 @@ func run(a app) int { return 0 } -func setLoggerOutput(a app) (func(), error) { - publicDir, err := a.PublicDir() - if err != nil { - return nil, err - } - - logFile := filepath.Join(publicDir, "log") - - // Move old log file - oldLogFile := filepath.Join(publicDir, "log.old") - err = os.Rename(logFile, oldLogFile) - if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Warnf("Could not archive previous log file: %v", err) - } - - f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE, 0600) - if err != nil { - return nil, fmt.Errorf("could not open log file: %v", err) - } - - // Write both to file and to Stdout. The latter is useful for local development. - w := io.MultiWriter(f, os.Stdout) - log.SetOutput(w) - - fmt.Fprintf(f, "\n======= STARTUP =======\n") - log.Infof("Version: %s", consts.Version) - - return func() { _ = f.Close() }, nil -} - func installSignalHandler(a app) func() { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) diff --git a/windows-agent/cmd/ubuntu-pro-agent/main_test.go b/windows-agent/cmd/ubuntu-pro-agent/main_test.go index 3a86b2fce..a415b463c 100644 --- a/windows-agent/cmd/ubuntu-pro-agent/main_test.go +++ b/windows-agent/cmd/ubuntu-pro-agent/main_test.go @@ -2,8 +2,6 @@ package main import ( "errors" - "os" - "path/filepath" "testing" "time" @@ -44,9 +42,6 @@ func (a *myApp) PublicDir() (string, error) { func TestRun(t *testing.T) { t.Parallel() - fooContent := "foo" - emptyContent := "" - tests := map[string]struct { existingLogContent string @@ -57,13 +52,7 @@ func TestRun(t *testing.T) { wantReturnCode int wantOldLogFileContent *string }{ - "Run and exit successfully": {}, - "Run and exit successfully despite logs not being written": {logDirError: true}, - - // Log file handling - "Existing log file has been renamed to old": {existingLogContent: "foo", wantOldLogFileContent: &fooContent}, - "Existing empty log file has been renamed to old": {existingLogContent: "-", wantOldLogFileContent: &emptyContent}, - "Ignore when failing to archive log file": {existingLogContent: "OLD_IS_DIRECTORY", wantReturnCode: 0}, + "Run and exit successfully": {}, // Error cases "Run and return error": {runError: true, wantReturnCode: 1}, @@ -81,31 +70,6 @@ func TestRun(t *testing.T) { tmpDir: t.TempDir(), } - if tc.logDirError { - a.tmpDir = "PUBLIC_DIR_ERROR" - } - - var logFile, oldLogFile string - publicDir, err := a.PublicDir() - if err == nil { - logFile = filepath.Join(publicDir, "log") - oldLogFile = logFile + ".old" - switch tc.existingLogContent { - case "": - case "OLD_IS_DIRECTORY": - err := os.Mkdir(oldLogFile, 0700) - require.NoError(t, err, "Setup: create invalid log.old file") - err = os.WriteFile(logFile, []byte("Old log content"), 0600) - require.NoError(t, err, "Setup: creating pre-existing log file") - case "-": - tc.existingLogContent = "" - fallthrough - default: - err := os.WriteFile(logFile, []byte(tc.existingLogContent), 0600) - require.NoError(t, err, "Setup: creating pre-existing log file") - } - } - var rc int wait := make(chan struct{}) go func() { @@ -119,19 +83,6 @@ func TestRun(t *testing.T) { <-wait require.Equal(t, tc.wantReturnCode, rc, "Return expected code") - - // Don't check for log files if the directory was not writable - if logFile == "" { - return - } - if tc.wantOldLogFileContent != nil { - require.FileExists(t, oldLogFile, "Old log file should exist") - content, err := os.ReadFile(oldLogFile) - require.NoError(t, err, "Should be able to read old log file") - require.Equal(t, tc.existingLogContent, string(content), "Old log file content should be log's content") - } else { - require.NoFileExists(t, oldLogFile, "Old log file should not exist") - } }) } }