diff --git a/pkg/os/shell/shell.go b/pkg/os/shell/shell.go index fbdac6aebf..2eb8da7164 100644 --- a/pkg/os/shell/shell.go +++ b/pkg/os/shell/shell.go @@ -6,6 +6,8 @@ import ( "os" "strings" + crcstrings "github.com/crc-org/crc/v2/pkg/strings" + "github.com/shirou/gopsutil/v4/process" "github.com/spf13/cast" @@ -208,7 +210,9 @@ func detectShellByCheckingProcessTree(p AbstractProcess) string { if err != nil { return "" } - if processName == "zsh" || processName == "bash" || processName == "fish" { + if crcstrings.IsPresentInListSatisfying(supportedShell, processName, func(listElem string, toMatch string) bool { + return strings.HasPrefix(toMatch, listElem) + }) { return processName } p, err = p.Parent() diff --git a/pkg/os/shell/shell_unix.go b/pkg/os/shell/shell_unix.go index 614060a76c..43df1f7b80 100644 --- a/pkg/os/shell/shell_unix.go +++ b/pkg/os/shell/shell_unix.go @@ -8,6 +8,10 @@ import ( "path/filepath" ) +var ( + supportedShell = []string{"bash", "zsh", "fish"} +) + // detect detects user's current shell. func detect() (string, error) { detectedShell := detectShellByCheckingProcessTree(currentProcessSupplier()) diff --git a/pkg/os/shell/shell_windows.go b/pkg/os/shell/shell_windows.go index 52f26654b3..fac1174adc 100644 --- a/pkg/os/shell/shell_windows.go +++ b/pkg/os/shell/shell_windows.go @@ -1,64 +1,19 @@ package shell import ( - "fmt" - "math" - "os" - "path/filepath" "sort" "strconv" "strings" - "syscall" - "unsafe" + + crcstrings "github.com/crc-org/crc/v2/pkg/strings" "github.com/crc-org/crc/v2/pkg/crc/logging" ) var ( - supportedShell = []string{"cmd", "powershell", "bash", "zsh", "fish"} + supportedShell = []string{"cmd", "powershell", "wsl", "bash", "zsh", "fish"} ) -// re-implementation of private function in https://github.com/golang/go/blob/master/src/syscall/syscall_windows.go -func getProcessEntry(pid uint32) (pe *syscall.ProcessEntry32, err error) { - snapshot, err := syscall.CreateToolhelp32Snapshot(syscall.TH32CS_SNAPPROCESS, 0) - if err != nil { - return nil, err - } - defer func() { - _ = syscall.CloseHandle(syscall.Handle(snapshot)) - }() - - var processEntry syscall.ProcessEntry32 - processEntry.Size = uint32(unsafe.Sizeof(processEntry)) - err = syscall.Process32First(snapshot, &processEntry) - if err != nil { - return nil, err - } - - for { - if processEntry.ProcessID == pid { - pe = &processEntry - return - } - - err = syscall.Process32Next(snapshot, &processEntry) - if err != nil { - return nil, err - } - } -} - -// getNameAndItsPpid returns the exe file name its parent process id. -func getNameAndItsPpid(pid uint32) (exefile string, parentid uint32, err error) { - pe, err := getProcessEntry(pid) - if err != nil { - return "", 0, err - } - - name := syscall.UTF16ToString(pe.ExeFile[:]) - return name, pe.ParentProcessID, nil -} - func shellType(shell string, defaultShell string) string { switch { case strings.Contains(strings.ToLower(shell), "powershell"): @@ -69,7 +24,7 @@ func shellType(shell string, defaultShell string) string { return "cmd" case strings.Contains(strings.ToLower(shell), "wsl"): return detectShellByInvokingCommand("bash", "wsl", []string{"-e", "bash", "-c", "ps -ao pid=,comm="}) - case filepath.IsAbs(shell) && strings.Contains(strings.ToLower(shell), "bash"): + case strings.Contains(strings.ToLower(shell), "bash"): return "bash" default: return defaultShell @@ -77,31 +32,7 @@ func shellType(shell string, defaultShell string) string { } func detect() (string, error) { - shell := os.Getenv("SHELL") - - if shell == "" { - pid := os.Getppid() - if pid < 0 || pid > math.MaxUint32 { - return "", fmt.Errorf("integer overflow for pid: %v", pid) - } - shell, shellppid, err := getNameAndItsPpid(uint32(pid)) - if err != nil { - return "cmd", err // defaulting to cmd - } - shell = shellType(shell, "") - if shell == "" { - shell, _, err := getNameAndItsPpid(shellppid) - if err != nil { - return "cmd", err // defaulting to cmd - } - return shellType(shell, "cmd"), nil - } - return shell, nil - } - - if os.Getenv("__fish_bin_dir") != "" { - return "fish", nil - } + shell := detectShellByCheckingProcessTree(currentProcessSupplier()) return shellType(shell, "cmd"), nil } @@ -163,9 +94,9 @@ func inspectProcessOutputForRecentlyUsedShell(psCommandOutput string) string { lines := strings.Split(psCommandOutput, "\n") for _, line := range lines { lineParts := strings.Split(strings.TrimSpace(line), " ") - if len(lineParts) == 2 && (strings.Contains(lineParts[1], "zsh") || - strings.Contains(lineParts[1], "bash") || - strings.Contains(lineParts[1], "fish")) { + if len(lineParts) == 2 && crcstrings.IsPresentInListSatisfying(supportedShell, lineParts[1], func(listElem string, toMatch string) bool { + return strings.HasPrefix(toMatch, listElem) + }) { parsedProcessID, err := strconv.Atoi(lineParts[0]) if err == nil { processOutputs = append(processOutputs, ProcessOutput{ diff --git a/pkg/os/shell/shell_windows_test.go b/pkg/os/shell/shell_windows_test.go index 06fecdad79..48efa2e34a 100644 --- a/pkg/os/shell/shell_windows_test.go +++ b/pkg/os/shell/shell_windows_test.go @@ -1,51 +1,137 @@ package shell import ( - "math" - "os" "testing" "github.com/stretchr/testify/assert" ) -func TestDetect(t *testing.T) { - defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL")) - os.Setenv("SHELL", "") - - shell, err := detect() +func TestDetect_WhenUnknownShell_ThenDefaultToCmdShell(t *testing.T) { + tests := []struct { + name string + processTree []MockedProcess + expectedShellType string + }{ + { + "failure to get process details for given pid", + []MockedProcess{}, + "", + }, + { + "failure while getting name of process", + []MockedProcess{ + { + name: "crc.exe", + }, + { + nameGetFails: true, + }, + }, + "", + }, + { + "failure while getting ppid of process", + []MockedProcess{ + { + name: "crc.exe", + }, + { + parentGetFails: true, + }, + }, + "", + }, + { + "failure when no shell process in process tree", + []MockedProcess{ + { + name: "crc.exe", + }, + { + name: "unknown.exe", + }, + }, + "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given + currentProcessSupplier = func() AbstractProcess { + return createNewMockProcessTreeFrom(tt.processTree) + } - assert.Contains(t, supportedShell, shell) - assert.NoError(t, err) -} + // When + shell, err := detect() -func TestGetNameAndItsPpidOfCurrent(t *testing.T) { - pid := os.Getpid() - if pid < 0 || pid > math.MaxUint32 { - assert.Fail(t, "integer overflow detected") - } - shell, shellppid, err := getNameAndItsPpid(uint32(pid)) - assert.Equal(t, "shell.test.exe", shell) - ppid := os.Getppid() - if ppid < 0 || ppid > math.MaxUint32 { - assert.Fail(t, "integer overflow detected") + // Then + assert.NoError(t, err) + assert.Equal(t, "cmd", shell) + }) } - assert.Equal(t, uint32(ppid), shellppid) - assert.NoError(t, err) } -func TestGetNameAndItsPpidOfParent(t *testing.T) { - pid := os.Getppid() - if pid < 0 || pid > math.MaxUint32 { - assert.Fail(t, "integer overflow detected") +func TestDetect_GivenProcessTree_ThenReturnShellProcessWithCorrespondingParentPID(t *testing.T) { + tests := []struct { + name string + processTree []MockedProcess + expectedShellType string + }{ + { + "bash shell, then detect bash shell", + []MockedProcess{ + { + name: "crc.exe", + }, + { + name: "bash.exe", + }, + }, + "bash", + }, + { + "powershell, then detect powershell", + []MockedProcess{ + { + name: "crc.exe", + }, + { + name: "powershell.exe", + }, + }, + "powershell", + }, + { + "cmd shell, then detect fish shell", + []MockedProcess{ + { + name: "crc.exe", + }, + { + name: "cmd.exe", + }, + }, + "cmd", + }, } - shell, _, err := getNameAndItsPpid(uint32(pid)) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given + currentProcessSupplier = func() AbstractProcess { + return createNewMockProcessTreeFrom(tt.processTree) + } + // When + shell, err := detect() - assert.Equal(t, "go.exe", shell) - assert.NoError(t, err) + // Then + assert.Equal(t, tt.expectedShellType, shell) + assert.NoError(t, err) + }) + } } func TestSupportedShells(t *testing.T) { - assert.Equal(t, []string{"cmd", "powershell", "bash", "zsh", "fish"}, supportedShell) + assert.Equal(t, []string{"cmd", "powershell", "wsl", "bash", "zsh", "fish"}, supportedShell) } func TestShellType(t *testing.T) { diff --git a/pkg/strings/strings.go b/pkg/strings/strings.go index 28af3b79c0..70d750a9c5 100644 --- a/pkg/strings/strings.go +++ b/pkg/strings/strings.go @@ -6,8 +6,14 @@ import ( ) func Contains(input []string, match string) bool { + return IsPresentInListSatisfying(input, match, func(listElement string, toMatch string) bool { + return listElement == toMatch + }) +} + +func IsPresentInListSatisfying(input []string, toMatch string, matchingPredicate func(string, string) bool) bool { for _, v := range input { - if v == match { + if matchingPredicate(v, toMatch) { return true } } diff --git a/pkg/strings/strings_test.go b/pkg/strings/strings_test.go index 08c37e931a..b9994a9661 100644 --- a/pkg/strings/strings_test.go +++ b/pkg/strings/strings_test.go @@ -1,6 +1,7 @@ package strings import ( + "strings" "testing" "github.com/stretchr/testify/assert" @@ -51,6 +52,54 @@ func TestContains(t *testing.T) { }) } +func TestIsPresentInListSatisfying(t *testing.T) { + tests := []struct { + name string + input []string + toMatch string + matchPredicate func(string, string) bool + expectedOutcome bool + }{ + { + name: "match found with case insensitive filter", + input: []string{"apple", "banana", "cherry"}, + toMatch: "Banana", + matchPredicate: strings.EqualFold, + expectedOutcome: true, + }, + { + name: "no match found with case insensitive filter", + input: []string{"apple", "banana", "cherry"}, + toMatch: "grape", + matchPredicate: strings.EqualFold, + expectedOutcome: false, + }, + { + name: "no match found for empty input list", + input: []string{}, + toMatch: "apple", + matchPredicate: strings.EqualFold, + expectedOutcome: false, + }, + { + name: "match found for prefix ", + input: []string{"cmd", "powershell", "bash"}, + toMatch: "cmd.test.exe", + matchPredicate: strings.HasPrefix, + expectedOutcome: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsPresentInListSatisfying(tt.input, tt.toMatch, tt.matchPredicate) + if result != tt.expectedOutcome { + t.Errorf("IsPresentInListSatisfying returned %v, expecting %v", result, tt.expectedOutcome) + } + }) + } +} + type splitLinesTest struct { input string splitOutput []string