From 04f1995c87e5135643ee05c58cefb10a2499881d Mon Sep 17 00:00:00 2001 From: Rohan Kumar Date: Thu, 19 Dec 2024 12:24:11 +0530 Subject: [PATCH] fix (shell) : Improve shell detection on windows (#3767) We should detect the usage of $SHELL environment variable when using CRC from linux like environments on Windows. We should also convert the CRC binary paths to unix path format whenever unix shells are detected. Signed-off-by: Rohan Kumar --- pkg/os/shell/shell.go | 14 ++++++- pkg/os/shell/shell_test.go | 65 ++++++++++++++++++++---------- pkg/os/shell/shell_unix_test.go | 20 +++++++++ pkg/os/shell/shell_windows.go | 7 +++- pkg/os/shell/shell_windows_test.go | 26 ++++++++++++ 5 files changed, 107 insertions(+), 25 deletions(-) diff --git a/pkg/os/shell/shell.go b/pkg/os/shell/shell.go index ccd776edc6..10a0b46141 100644 --- a/pkg/os/shell/shell.go +++ b/pkg/os/shell/shell.go @@ -67,7 +67,7 @@ func GetEnvString(userShell string, envName string, envValue string) string { case "fish": return fmt.Sprintf("contains %s $fish_user_paths; or set -U fish_user_paths %s $fish_user_paths", envValue, envValue) default: - return fmt.Sprintf("export %s=\"%s\"", envName, envValue) + return fmt.Sprintf("export %s=\"%s\"", envName, convertToLinuxStylePath(userShell, envValue)) } } @@ -81,8 +81,18 @@ func GetPathEnvString(userShell string, prependedPath string) string { case "cmd": pathStr = fmt.Sprintf("%s;%%PATH%%", prependedPath) default: - pathStr = fmt.Sprintf("%s:$PATH", prependedPath) + pathStr = fmt.Sprintf("%s:$PATH", convertToLinuxStylePath(userShell, prependedPath)) } return GetEnvString(userShell, "PATH", pathStr) } + +func convertToLinuxStylePath(userShell string, path string) string { + if strings.Contains(path, "\\") && + (userShell == "bash" || userShell == "zsh" || userShell == "fish") { + path = strings.Replace(path, ":", "", -1) + path = strings.Replace(path, "\\", "/", -1) + return fmt.Sprintf("/%s", path) + } + return path +} diff --git a/pkg/os/shell/shell_test.go b/pkg/os/shell/shell_test.go index e9e7a002b8..1f54cd8d2a 100644 --- a/pkg/os/shell/shell_test.go +++ b/pkg/os/shell/shell_test.go @@ -1,31 +1,52 @@ -//go:build !windows -// +build !windows - package shell import ( - "os" "testing" - - "github.com/stretchr/testify/assert" ) -func TestDetectBash(t *testing.T) { - defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL")) - os.Setenv("SHELL", "/bin/bash") - - shell, err := detect() - - assert.Equal(t, "bash", shell) - assert.NoError(t, err) +func TestGetPathEnvString(t *testing.T) { + tests := []struct { + name string + userShell string + path string + expectedStr string + }{ + {"fish shell", "fish", "C:\\Users\\foo\\.crc\\bin\\oc", "contains C:\\Users\\foo\\.crc\\bin\\oc $fish_user_paths; or set -U fish_user_paths C:\\Users\\foo\\.crc\\bin\\oc $fish_user_paths"}, + {"powershell shell", "powershell", "C:\\Users\\foo\\oc.exe", "$Env:PATH = \"C:\\Users\\foo\\oc.exe;$Env:PATH\""}, + {"cmd shell", "cmd", "C:\\Users\\foo\\oc.exe", "SET PATH=C:\\Users\\foo\\oc.exe;%PATH%"}, + {"bash with windows path", "bash", "C:\\Users\\foo.exe", "export PATH=\"/C/Users/foo.exe:$PATH\""}, + {"unknown with windows path", "unknown", "C:\\Users\\foo.exe", "export PATH=\"C:\\Users\\foo.exe:$PATH\""}, + {"unknown shell with unix path", "unknown", "/home/foo/.crc/bin/oc", "export PATH=\"/home/foo/.crc/bin/oc:$PATH\""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetPathEnvString(tt.userShell, tt.path) + if result != tt.expectedStr { + t.Errorf("GetPathEnvString(%s, %s) = %s; want %s", tt.userShell, tt.path, result, tt.expectedStr) + } + }) + } } -func TestDetectFish(t *testing.T) { - defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL")) - os.Setenv("SHELL", "/bin/fish") - - shell, err := detect() - - assert.Equal(t, "fish", shell) - assert.NoError(t, err) +func TestConvertToLinuxStylePath(t *testing.T) { + tests := []struct { + name string + userShell string + path string + expectedPath string + }{ + {"bash on windows, should convert", "bash", "C:\\Users\\foo\\.crc\\bin\\oc", "/C/Users/foo/.crc/bin/oc"}, + {"zsh on windows, should convert", "zsh", "C:\\Users\\foo\\.crc\\bin\\oc", "/C/Users/foo/.crc/bin/oc"}, + {"fish on windows, should convert", "fish", "C:\\Users\\foo\\.crc\\bin\\oc", "/C/Users/foo/.crc/bin/oc"}, + {"powershell on windows, should NOT convert", "powershell", "C:\\Users\\foo\\.crc\\bin\\oc", "C:\\Users\\foo\\.crc\\bin\\oc"}, + {"cmd on windows, should NOT convert", "cmd", "C:\\Users\\foo\\.crc\\bin\\oc", "C:\\Users\\foo\\.crc\\bin\\oc"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertToLinuxStylePath(tt.userShell, tt.path) + if result != tt.expectedPath { + t.Errorf("convertToLinuxStylePath(%s, %s) = %s; want %s", tt.userShell, tt.path, result, tt.expectedPath) + } + }) + } } diff --git a/pkg/os/shell/shell_unix_test.go b/pkg/os/shell/shell_unix_test.go index 18fb907f65..68d6e1989e 100644 --- a/pkg/os/shell/shell_unix_test.go +++ b/pkg/os/shell/shell_unix_test.go @@ -19,3 +19,23 @@ func TestUnknownShell(t *testing.T) { assert.Equal(t, err, ErrUnknownShell) assert.Empty(t, shell) } + +func TestDetectBash(t *testing.T) { + defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL")) + os.Setenv("SHELL", "/bin/bash") + + shell, err := detect() + + assert.Equal(t, "bash", shell) + assert.NoError(t, err) +} + +func TestDetectFish(t *testing.T) { + defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL")) + os.Setenv("SHELL", "/bin/fish") + + shell, err := detect() + + assert.Equal(t, "fish", shell) + assert.NoError(t, err) +} diff --git a/pkg/os/shell/shell_windows.go b/pkg/os/shell/shell_windows.go index 5075bb7402..263191d970 100644 --- a/pkg/os/shell/shell_windows.go +++ b/pkg/os/shell/shell_windows.go @@ -4,13 +4,14 @@ import ( "fmt" "math" "os" + "path/filepath" "strings" "syscall" "unsafe" ) var ( - supportedShell = []string{"cmd", "powershell"} + supportedShell = []string{"cmd", "powershell", "bash", "zsh"} ) // re-implementation of private function in https://github.com/golang/go/blob/master/src/syscall/syscall_windows.go @@ -62,6 +63,10 @@ func shellType(shell string, defaultShell string) string { return "powershell" case strings.Contains(strings.ToLower(shell), "cmd"): return "cmd" + case filepath.IsAbs(shell) && strings.Contains(strings.ToLower(shell), "bash"): + return "bash" + case filepath.IsAbs(shell) && strings.Contains(strings.ToLower(shell), "zsh"): + return "zsh" default: return defaultShell } diff --git a/pkg/os/shell/shell_windows_test.go b/pkg/os/shell/shell_windows_test.go index 381b2947c9..a15b9b49cf 100644 --- a/pkg/os/shell/shell_windows_test.go +++ b/pkg/os/shell/shell_windows_test.go @@ -43,3 +43,29 @@ func TestGetNameAndItsPpidOfParent(t *testing.T) { assert.Equal(t, "go.exe", shell) assert.NoError(t, err) } + +func TestSupportedShells(t *testing.T) { + assert.Equal(t, []string{"cmd", "powershell", "bash", "zsh"}, supportedShell) +} + +func TestShellType(t *testing.T) { + tests := []struct { + name string + userShell string + expectedShellType string + }{ + {"git bash", "C:\\Program Files\\Git\\usr\\bin\\bash.exe", "bash"}, + {"powershell", "powershell", "powershell"}, + {"cmd.exe", "cmd.exe", "cmd"}, + {"pwsh", "pwsh.exe", "powershell"}, + {"empty value", "", "cmd"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shellType(tt.userShell, "cmd") + if result != tt.expectedShellType { + t.Errorf("shellType(%s) = %s; want %s", tt.userShell, result, tt.expectedShellType) + } + }) + } +}