Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor (shell) : detect shell on windows using gopsutil too (#4588) #4591

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions pkg/os/shell/shell.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
package shell

import (
"errors"
"fmt"
"os"
"slices"
"strings"

"github.com/shirou/gopsutil/v4/process"
"github.com/spf13/cast"

crcos "github.com/crc-org/crc/v2/pkg/os"
)

var (
CommandRunner = crcos.NewLocalCommandRunner()
WindowsSubsystemLinuxKernelMetadataFile = "/proc/version"
ErrUnknownShell = errors.New("Error: Unknown shell")
currentProcessSupplier = createCurrentProcess
)

type Config struct {
Expand All @@ -20,6 +27,20 @@ type Config struct {
PathSuffix string
}

// AbstractProcess is an interface created to abstract operations of the gopsutil library
// It is created so that we can override the behavior while writing unit tests by providing
// a mock implementation.
type AbstractProcess interface {
Name() (string, error)
Parent() (AbstractProcess, error)
}

// RealProcess is a wrapper implementation of AbstractProcess to wrap around the gopsutil library's
// process.Process object. This implementation is used in production code.
type RealProcess struct {
*process.Process
}

func GetShell(userShell string) (string, error) {
if userShell != "" {
if !isSupportedShell(userShell) {
Expand Down Expand Up @@ -151,3 +172,52 @@ func IsWindowsSubsystemLinux() bool {
}
return false
}

func (p *RealProcess) Parent() (AbstractProcess, error) {
parentProcess, err := p.Process.Parent()
if err != nil {
return nil, err
}
return &RealProcess{parentProcess}, nil
}

func createCurrentProcess() AbstractProcess {
currentProcess, err := process.NewProcess(cast.ToInt32(os.Getpid()))
if err != nil {
return nil
}
return &RealProcess{currentProcess}
}

// detectShellByCheckingProcessTree attempts to identify the shell being used by
// examining the process tree starting from the given process ID. This function
// traverses up to ProcessDepthLimit levels up the process hierarchy.
// Parameters:
// - pid (int): The process ID to start checking from.
//
// Returns:
// - string: The name of the shell if found (e.g., "zsh", "bash", "fish");
// otherwise, an empty string is returned if no matching shell is detected
// or an error occurs during the process tree traversal.
//
// Examples:
//
// shellName := detectShellByCheckingProcessTree(1234)
func detectShellByCheckingProcessTree(p AbstractProcess) string {
for p != nil {
processName, err := p.Name()
if err != nil {
return ""
}
if slices.ContainsFunc(supportedShell, func(listElem string) bool {
return strings.HasPrefix(processName, listElem)
}) {
cfergeau marked this conversation as resolved.
Show resolved Hide resolved
return processName
}
p, err = p.Parent()
if err != nil {
return ""
}
}
return ""
}
5 changes: 0 additions & 5 deletions pkg/os/shell/shell_darwin.go

This file was deleted.

5 changes: 0 additions & 5 deletions pkg/os/shell/shell_linux.go

This file was deleted.

36 changes: 36 additions & 0 deletions pkg/os/shell/shell_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package shell

import (
"errors"
"os"
"path/filepath"
"testing"
Expand Down Expand Up @@ -47,6 +48,28 @@ func (e *MockCommandRunner) RunPrivileged(_ string, cmdAndArgs ...string) (strin
return e.expectedOutputToReturn, e.expectedErrMessageToReturn, e.expectedErrToReturn
}

// MockedProcess is a mock implementation of AbstractProcess for testing purposes.
type MockedProcess struct {
name string
parent *MockedProcess
nameGetFails bool
parentGetFails bool
}

func (m MockedProcess) Parent() (AbstractProcess, error) {
if m.parentGetFails || m.parent == nil {
return nil, errors.New("failed to get the pid")
}
return m.parent, nil
}

func (m MockedProcess) Name() (string, error) {
if m.nameGetFails {
return "", errors.New("failed to get the name")
}
return m.name, nil
}

func TestGetPathEnvString(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -179,3 +202,16 @@ func TestConvertToWindowsSubsystemLinuxPath(t *testing.T) {
assert.Equal(t, "wsl", mockCommandExecutor.commandName)
assert.Equal(t, []string{"-e", "bash", "-c", "wslpath -a 'C:\\Users\\foo\\.crc\\bin\\oc'"}, mockCommandExecutor.commandArgs)
}

func createNewMockProcessTreeFrom(processes []MockedProcess) AbstractProcess {
if len(processes) == 0 {
return nil
}
head := &processes[0]
current := head
for i := 1; i < len(processes); i++ {
current.parent = &processes[i]
current = current.parent
}
return head
}
69 changes: 1 addition & 68 deletions pkg/os/shell/shell_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,14 @@
package shell

import (
"errors"
"fmt"
"os"
"path/filepath"

"github.com/shirou/gopsutil/v4/process"
"github.com/spf13/cast"
)

var (
ErrUnknownShell = errors.New("Error: Unknown shell")
currentProcessSupplier = createCurrentProcess
supportedShell = []string{"bash", "zsh", "fish"}
)

// AbstractProcess is an interface created to abstract operations of the gopsutil library
// It is created so that we can override the behavior while writing unit tests by providing
// a mock implementation.
type AbstractProcess interface {
Name() (string, error)
Parent() (AbstractProcess, error)
}

// RealProcess is a wrapper implementation of AbstractProcess to wrap around the gopsutil library's
// process.Process object. This implementation is used in production code.
type RealProcess struct {
*process.Process
}

func (p *RealProcess) Parent() (AbstractProcess, error) {
parentProcess, err := p.Process.Parent()
if err != nil {
return nil, err
}
return &RealProcess{parentProcess}, nil
}

func createCurrentProcess() AbstractProcess {
currentProcess, err := process.NewProcess(cast.ToInt32(os.Getpid()))
if err != nil {
return nil
}
return &RealProcess{currentProcess}
}

// detect detects user's current shell.
func detect() (string, error) {
detectedShell := detectShellByCheckingProcessTree(currentProcessSupplier())
Expand All @@ -58,34 +22,3 @@ func detect() (string, error) {

return filepath.Base(detectedShell), nil
}

// detectShellByCheckingProcessTree attempts to identify the shell being used by
// examining the process tree starting from the given process ID. This function
// traverses up to ProcessDepthLimit levels up the process hierarchy.
// Parameters:
// - pid (int): The process ID to start checking from.
//
// Returns:
// - string: The name of the shell if found (e.g., "zsh", "bash", "fish");
// otherwise, an empty string is returned if no matching shell is detected
// or an error occurs during the process tree traversal.
//
// Examples:
//
// shellName := detectShellByCheckingProcessTree(1234)
func detectShellByCheckingProcessTree(p AbstractProcess) string {
for p != nil {
processName, err := p.Name()
if err != nil {
return ""
}
if processName == "zsh" || processName == "bash" || processName == "fish" {
return processName
}
p, err = p.Parent()
if err != nil {
return ""
}
}
return ""
}
36 changes: 0 additions & 36 deletions pkg/os/shell/shell_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,12 @@ package shell

import (
"bytes"
"errors"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

// MockedProcess is a mock implementation of AbstractProcess for testing purposes.
type MockedProcess struct {
name string
parent *MockedProcess
nameGetFails bool
parentGetFails bool
}

func (m MockedProcess) Parent() (AbstractProcess, error) {
if m.parentGetFails || m.parent == nil {
return nil, errors.New("failed to get the pid")
}
return m.parent, nil
}

func (m MockedProcess) Name() (string, error) {
if m.nameGetFails {
return "", errors.New("failed to get the name")
}
return m.name, nil
}

func TestUnknownShell(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -183,16 +160,3 @@ func TestGetCurrentProcess(t *testing.T) {
assert.NoError(t, err)
assert.Greater(t, len(currentProcessName), 0)
}

func createNewMockProcessTreeFrom(processes []MockedProcess) AbstractProcess {
if len(processes) == 0 {
return nil
}
head := &processes[0]
current := head
for i := 1; i < len(processes); i++ {
current.parent = &processes[i]
current = current.parent
}
return head
}
Loading