diff --git a/pkg/os/shell/shell.go b/pkg/os/shell/shell.go index 96af153041..6a307bf493 100644 --- a/pkg/os/shell/shell.go +++ b/pkg/os/shell/shell.go @@ -1,16 +1,23 @@ package shell import ( + "errors" "fmt" "os" + "slices" "strings" + "github.com/shirou/gopsutil/v4/process" + "github.com/spf13/cast" + crcos "github.com/crc-org/crc/v2/pkg/os" ) var ( CommandRunner = crcos.NewLocalCommandRunner() WindowsSubsystemLinuxKernelMetadataFile = "/proc/version" + ErrUnknownShell = errors.New("Error: Unknown shell") + currentProcessSupplier = createCurrentProcess ) type Config struct { @@ -20,6 +27,20 @@ type Config struct { PathSuffix string } +// AbstractProcess is an interface created to abstract operations of the gopsutil library +// It is created so that we can override the behavior while writing unit tests by providing +// a mock implementation. +type AbstractProcess interface { + Name() (string, error) + Parent() (AbstractProcess, error) +} + +// RealProcess is a wrapper implementation of AbstractProcess to wrap around the gopsutil library's +// process.Process object. This implementation is used in production code. +type RealProcess struct { + *process.Process +} + func GetShell(userShell string) (string, error) { if userShell != "" { if !isSupportedShell(userShell) { @@ -151,3 +172,52 @@ func IsWindowsSubsystemLinux() bool { } return false } + +func (p *RealProcess) Parent() (AbstractProcess, error) { + parentProcess, err := p.Process.Parent() + if err != nil { + return nil, err + } + return &RealProcess{parentProcess}, nil +} + +func createCurrentProcess() AbstractProcess { + currentProcess, err := process.NewProcess(cast.ToInt32(os.Getpid())) + if err != nil { + return nil + } + return &RealProcess{currentProcess} +} + +// detectShellByCheckingProcessTree attempts to identify the shell being used by +// examining the process tree starting from the given process ID. This function +// traverses up to ProcessDepthLimit levels up the process hierarchy. +// Parameters: +// - pid (int): The process ID to start checking from. +// +// Returns: +// - string: The name of the shell if found (e.g., "zsh", "bash", "fish"); +// otherwise, an empty string is returned if no matching shell is detected +// or an error occurs during the process tree traversal. +// +// Examples: +// +// shellName := detectShellByCheckingProcessTree(1234) +func detectShellByCheckingProcessTree(p AbstractProcess) string { + for p != nil { + processName, err := p.Name() + if err != nil { + return "" + } + if slices.ContainsFunc(supportedShell, func(listElem string) bool { + return strings.HasPrefix(processName, listElem) + }) { + return processName + } + p, err = p.Parent() + if err != nil { + return "" + } + } + return "" +} diff --git a/pkg/os/shell/shell_darwin.go b/pkg/os/shell/shell_darwin.go deleted file mode 100644 index a7038da061..0000000000 --- a/pkg/os/shell/shell_darwin.go +++ /dev/null @@ -1,5 +0,0 @@ -package shell - -var ( - supportedShell = []string{"bash", "zsh", "fish"} -) diff --git a/pkg/os/shell/shell_linux.go b/pkg/os/shell/shell_linux.go deleted file mode 100644 index a7038da061..0000000000 --- a/pkg/os/shell/shell_linux.go +++ /dev/null @@ -1,5 +0,0 @@ -package shell - -var ( - supportedShell = []string{"bash", "zsh", "fish"} -) diff --git a/pkg/os/shell/shell_test.go b/pkg/os/shell/shell_test.go index a707555663..d0ba51133e 100644 --- a/pkg/os/shell/shell_test.go +++ b/pkg/os/shell/shell_test.go @@ -1,6 +1,7 @@ package shell import ( + "errors" "os" "path/filepath" "testing" @@ -47,6 +48,28 @@ func (e *MockCommandRunner) RunPrivileged(_ string, cmdAndArgs ...string) (strin return e.expectedOutputToReturn, e.expectedErrMessageToReturn, e.expectedErrToReturn } +// MockedProcess is a mock implementation of AbstractProcess for testing purposes. +type MockedProcess struct { + name string + parent *MockedProcess + nameGetFails bool + parentGetFails bool +} + +func (m MockedProcess) Parent() (AbstractProcess, error) { + if m.parentGetFails || m.parent == nil { + return nil, errors.New("failed to get the pid") + } + return m.parent, nil +} + +func (m MockedProcess) Name() (string, error) { + if m.nameGetFails { + return "", errors.New("failed to get the name") + } + return m.name, nil +} + func TestGetPathEnvString(t *testing.T) { tests := []struct { name string @@ -179,3 +202,16 @@ func TestConvertToWindowsSubsystemLinuxPath(t *testing.T) { assert.Equal(t, "wsl", mockCommandExecutor.commandName) assert.Equal(t, []string{"-e", "bash", "-c", "wslpath -a 'C:\\Users\\foo\\.crc\\bin\\oc'"}, mockCommandExecutor.commandArgs) } + +func createNewMockProcessTreeFrom(processes []MockedProcess) AbstractProcess { + if len(processes) == 0 { + return nil + } + head := &processes[0] + current := head + for i := 1; i < len(processes); i++ { + current.parent = &processes[i] + current = current.parent + } + return head +} diff --git a/pkg/os/shell/shell_unix.go b/pkg/os/shell/shell_unix.go index 649f3882f5..43df1f7b80 100644 --- a/pkg/os/shell/shell_unix.go +++ b/pkg/os/shell/shell_unix.go @@ -4,50 +4,14 @@ package shell import ( - "errors" "fmt" - "os" "path/filepath" - - "github.com/shirou/gopsutil/v4/process" - "github.com/spf13/cast" ) var ( - ErrUnknownShell = errors.New("Error: Unknown shell") - currentProcessSupplier = createCurrentProcess + supportedShell = []string{"bash", "zsh", "fish"} ) -// AbstractProcess is an interface created to abstract operations of the gopsutil library -// It is created so that we can override the behavior while writing unit tests by providing -// a mock implementation. -type AbstractProcess interface { - Name() (string, error) - Parent() (AbstractProcess, error) -} - -// RealProcess is a wrapper implementation of AbstractProcess to wrap around the gopsutil library's -// process.Process object. This implementation is used in production code. -type RealProcess struct { - *process.Process -} - -func (p *RealProcess) Parent() (AbstractProcess, error) { - parentProcess, err := p.Process.Parent() - if err != nil { - return nil, err - } - return &RealProcess{parentProcess}, nil -} - -func createCurrentProcess() AbstractProcess { - currentProcess, err := process.NewProcess(cast.ToInt32(os.Getpid())) - if err != nil { - return nil - } - return &RealProcess{currentProcess} -} - // detect detects user's current shell. func detect() (string, error) { detectedShell := detectShellByCheckingProcessTree(currentProcessSupplier()) @@ -58,34 +22,3 @@ func detect() (string, error) { return filepath.Base(detectedShell), nil } - -// detectShellByCheckingProcessTree attempts to identify the shell being used by -// examining the process tree starting from the given process ID. This function -// traverses up to ProcessDepthLimit levels up the process hierarchy. -// Parameters: -// - pid (int): The process ID to start checking from. -// -// Returns: -// - string: The name of the shell if found (e.g., "zsh", "bash", "fish"); -// otherwise, an empty string is returned if no matching shell is detected -// or an error occurs during the process tree traversal. -// -// Examples: -// -// shellName := detectShellByCheckingProcessTree(1234) -func detectShellByCheckingProcessTree(p AbstractProcess) string { - for p != nil { - processName, err := p.Name() - if err != nil { - return "" - } - if processName == "zsh" || processName == "bash" || processName == "fish" { - return processName - } - p, err = p.Parent() - if err != nil { - return "" - } - } - return "" -} diff --git a/pkg/os/shell/shell_unix_test.go b/pkg/os/shell/shell_unix_test.go index 524825d134..a9c24cc086 100644 --- a/pkg/os/shell/shell_unix_test.go +++ b/pkg/os/shell/shell_unix_test.go @@ -5,35 +5,12 @@ package shell import ( "bytes" - "errors" "os" "testing" "github.com/stretchr/testify/assert" ) -// MockedProcess is a mock implementation of AbstractProcess for testing purposes. -type MockedProcess struct { - name string - parent *MockedProcess - nameGetFails bool - parentGetFails bool -} - -func (m MockedProcess) Parent() (AbstractProcess, error) { - if m.parentGetFails || m.parent == nil { - return nil, errors.New("failed to get the pid") - } - return m.parent, nil -} - -func (m MockedProcess) Name() (string, error) { - if m.nameGetFails { - return "", errors.New("failed to get the name") - } - return m.name, nil -} - func TestUnknownShell(t *testing.T) { tests := []struct { name string @@ -183,16 +160,3 @@ func TestGetCurrentProcess(t *testing.T) { assert.NoError(t, err) assert.Greater(t, len(currentProcessName), 0) } - -func createNewMockProcessTreeFrom(processes []MockedProcess) AbstractProcess { - if len(processes) == 0 { - return nil - } - head := &processes[0] - current := head - for i := 1; i < len(processes); i++ { - current.parent = &processes[i] - current = current.parent - } - return head -} diff --git a/pkg/os/shell/shell_windows.go b/pkg/os/shell/shell_windows.go index 52f26654b3..307a85a1a4 100644 --- a/pkg/os/shell/shell_windows.go +++ b/pkg/os/shell/shell_windows.go @@ -1,64 +1,18 @@ package shell import ( - "fmt" - "math" - "os" - "path/filepath" + "slices" "sort" "strconv" "strings" - "syscall" - "unsafe" "github.com/crc-org/crc/v2/pkg/crc/logging" ) var ( - supportedShell = []string{"cmd", "powershell", "bash", "zsh", "fish"} + supportedShell = []string{"cmd", "powershell", "wsl", "bash", "zsh", "fish"} ) -// re-implementation of private function in https://github.com/golang/go/blob/master/src/syscall/syscall_windows.go -func getProcessEntry(pid uint32) (pe *syscall.ProcessEntry32, err error) { - snapshot, err := syscall.CreateToolhelp32Snapshot(syscall.TH32CS_SNAPPROCESS, 0) - if err != nil { - return nil, err - } - defer func() { - _ = syscall.CloseHandle(syscall.Handle(snapshot)) - }() - - var processEntry syscall.ProcessEntry32 - processEntry.Size = uint32(unsafe.Sizeof(processEntry)) - err = syscall.Process32First(snapshot, &processEntry) - if err != nil { - return nil, err - } - - for { - if processEntry.ProcessID == pid { - pe = &processEntry - return - } - - err = syscall.Process32Next(snapshot, &processEntry) - if err != nil { - return nil, err - } - } -} - -// getNameAndItsPpid returns the exe file name its parent process id. -func getNameAndItsPpid(pid uint32) (exefile string, parentid uint32, err error) { - pe, err := getProcessEntry(pid) - if err != nil { - return "", 0, err - } - - name := syscall.UTF16ToString(pe.ExeFile[:]) - return name, pe.ParentProcessID, nil -} - func shellType(shell string, defaultShell string) string { switch { case strings.Contains(strings.ToLower(shell), "powershell"): @@ -69,7 +23,7 @@ func shellType(shell string, defaultShell string) string { return "cmd" case strings.Contains(strings.ToLower(shell), "wsl"): return detectShellByInvokingCommand("bash", "wsl", []string{"-e", "bash", "-c", "ps -ao pid=,comm="}) - case filepath.IsAbs(shell) && strings.Contains(strings.ToLower(shell), "bash"): + case strings.Contains(strings.ToLower(shell), "bash"): return "bash" default: return defaultShell @@ -77,31 +31,7 @@ func shellType(shell string, defaultShell string) string { } func detect() (string, error) { - shell := os.Getenv("SHELL") - - if shell == "" { - pid := os.Getppid() - if pid < 0 || pid > math.MaxUint32 { - return "", fmt.Errorf("integer overflow for pid: %v", pid) - } - shell, shellppid, err := getNameAndItsPpid(uint32(pid)) - if err != nil { - return "cmd", err // defaulting to cmd - } - shell = shellType(shell, "") - if shell == "" { - shell, _, err := getNameAndItsPpid(shellppid) - if err != nil { - return "cmd", err // defaulting to cmd - } - return shellType(shell, "cmd"), nil - } - return shell, nil - } - - if os.Getenv("__fish_bin_dir") != "" { - return "fish", nil - } + shell := detectShellByCheckingProcessTree(currentProcessSupplier()) return shellType(shell, "cmd"), nil } @@ -163,9 +93,9 @@ func inspectProcessOutputForRecentlyUsedShell(psCommandOutput string) string { lines := strings.Split(psCommandOutput, "\n") for _, line := range lines { lineParts := strings.Split(strings.TrimSpace(line), " ") - if len(lineParts) == 2 && (strings.Contains(lineParts[1], "zsh") || - strings.Contains(lineParts[1], "bash") || - strings.Contains(lineParts[1], "fish")) { + if len(lineParts) == 2 && slices.ContainsFunc(supportedShell, func(listElem string) bool { + return strings.HasPrefix(lineParts[1], listElem) + }) { parsedProcessID, err := strconv.Atoi(lineParts[0]) if err == nil { processOutputs = append(processOutputs, ProcessOutput{ diff --git a/pkg/os/shell/shell_windows_test.go b/pkg/os/shell/shell_windows_test.go index 06fecdad79..48efa2e34a 100644 --- a/pkg/os/shell/shell_windows_test.go +++ b/pkg/os/shell/shell_windows_test.go @@ -1,51 +1,137 @@ package shell import ( - "math" - "os" "testing" "github.com/stretchr/testify/assert" ) -func TestDetect(t *testing.T) { - defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL")) - os.Setenv("SHELL", "") - - shell, err := detect() +func TestDetect_WhenUnknownShell_ThenDefaultToCmdShell(t *testing.T) { + tests := []struct { + name string + processTree []MockedProcess + expectedShellType string + }{ + { + "failure to get process details for given pid", + []MockedProcess{}, + "", + }, + { + "failure while getting name of process", + []MockedProcess{ + { + name: "crc.exe", + }, + { + nameGetFails: true, + }, + }, + "", + }, + { + "failure while getting ppid of process", + []MockedProcess{ + { + name: "crc.exe", + }, + { + parentGetFails: true, + }, + }, + "", + }, + { + "failure when no shell process in process tree", + []MockedProcess{ + { + name: "crc.exe", + }, + { + name: "unknown.exe", + }, + }, + "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given + currentProcessSupplier = func() AbstractProcess { + return createNewMockProcessTreeFrom(tt.processTree) + } - assert.Contains(t, supportedShell, shell) - assert.NoError(t, err) -} + // When + shell, err := detect() -func TestGetNameAndItsPpidOfCurrent(t *testing.T) { - pid := os.Getpid() - if pid < 0 || pid > math.MaxUint32 { - assert.Fail(t, "integer overflow detected") - } - shell, shellppid, err := getNameAndItsPpid(uint32(pid)) - assert.Equal(t, "shell.test.exe", shell) - ppid := os.Getppid() - if ppid < 0 || ppid > math.MaxUint32 { - assert.Fail(t, "integer overflow detected") + // Then + assert.NoError(t, err) + assert.Equal(t, "cmd", shell) + }) } - assert.Equal(t, uint32(ppid), shellppid) - assert.NoError(t, err) } -func TestGetNameAndItsPpidOfParent(t *testing.T) { - pid := os.Getppid() - if pid < 0 || pid > math.MaxUint32 { - assert.Fail(t, "integer overflow detected") +func TestDetect_GivenProcessTree_ThenReturnShellProcessWithCorrespondingParentPID(t *testing.T) { + tests := []struct { + name string + processTree []MockedProcess + expectedShellType string + }{ + { + "bash shell, then detect bash shell", + []MockedProcess{ + { + name: "crc.exe", + }, + { + name: "bash.exe", + }, + }, + "bash", + }, + { + "powershell, then detect powershell", + []MockedProcess{ + { + name: "crc.exe", + }, + { + name: "powershell.exe", + }, + }, + "powershell", + }, + { + "cmd shell, then detect fish shell", + []MockedProcess{ + { + name: "crc.exe", + }, + { + name: "cmd.exe", + }, + }, + "cmd", + }, } - shell, _, err := getNameAndItsPpid(uint32(pid)) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given + currentProcessSupplier = func() AbstractProcess { + return createNewMockProcessTreeFrom(tt.processTree) + } + // When + shell, err := detect() - assert.Equal(t, "go.exe", shell) - assert.NoError(t, err) + // Then + assert.Equal(t, tt.expectedShellType, shell) + assert.NoError(t, err) + }) + } } func TestSupportedShells(t *testing.T) { - assert.Equal(t, []string{"cmd", "powershell", "bash", "zsh", "fish"}, supportedShell) + assert.Equal(t, []string{"cmd", "powershell", "wsl", "bash", "zsh", "fish"}, supportedShell) } func TestShellType(t *testing.T) {