Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more tests to improve coverage #414

Merged
merged 4 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion internal/k8s/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ import (
"k8s.io/client-go/tools/remotecommand"
)

const (
idleAnnotation = "idling.amazee.io/unidle-replicas"
)

// podContainer returns the first pod and first container inside that pod for
// the given namespace and deployment.
func (c *Client) podContainer(ctx context.Context, namespace,
Expand Down Expand Up @@ -68,7 +72,7 @@ func (c *Client) hasRunningPod(ctx context.Context,
// replicas to restore. If the label cannot be read or parsed, 1 is returned.
// The return value is clamped to the interval [1,16].
func unidleReplicas(deploy appsv1.Deployment) int {
rs, ok := deploy.Annotations["idling.amazee.io/unidle-replicas"]
rs, ok := deploy.Annotations[idleAnnotation]
if !ok {
return 1
}
Expand Down
37 changes: 37 additions & 0 deletions internal/k8s/exec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package k8s

import (
"testing"

"github.com/alecthomas/assert/v2"
appsv1 "k8s.io/api/apps/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func TestUnidleReplicas(t *testing.T) {
var testCases = map[string]struct {
input string
expect int
}{
"simple": {input: "4", expect: 4},
"high edge": {input: "16", expect: 16},
"low edge": {input: "1", expect: 1},
"zero": {input: "0", expect: 1},
"too high": {input: "17", expect: 16},
"way too high": {input: "17000000", expect: 16},
"overflow too high": {input: "9223372036854775808", expect: 1},
"too low": {input: "-1", expect: 1},
"way too low": {input: "-17000000", expect: 1},
"overflow too low": {input: "-9223372036854775808", expect: 1},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
deploy := appsv1.Deployment{
ObjectMeta: metav1.ObjectMeta{
Annotations: map[string]string{idleAnnotation: tc.input},
},
}
assert.Equal(tt, tc.expect, unidleReplicas(deploy), name)
})
}
}
46 changes: 46 additions & 0 deletions internal/k8s/logs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package k8s

import (
"context"
"io"
"strings"
"testing"
"time"

"github.com/alecthomas/assert/v2"
)

func TestLinewiseCopy(t *testing.T) {
var testCases = map[string]struct {
input string
expect []string
prefix string
}{
"logs": {
input: "foo\nbar\nbaz\n",
expect: []string{"test: foo", "test: bar", "test: baz"},
prefix: "test:",
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
out := make(chan string, 1)
in := io.NopCloser(strings.NewReader(tc.input))
go linewiseCopy(ctx, tc.prefix, out, in)
timer := time.NewTimer(500 * time.Millisecond)
var lines []string
loop:
for {
select {
case <-timer.C:
break loop
case line := <-out:
lines = append(lines, line)
}
}
assert.Equal(tt, tc.expect, lines, name)
})
}
}
41 changes: 41 additions & 0 deletions internal/k8s/namespacedetails_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package k8s

import (
"testing"

"github.com/alecthomas/assert/v2"
)

func TestIntFromLabel(t *testing.T) {
labels := map[string]string{
"foo": "1",
"bar": "hello",
"baz": "true",
"negative": "-1",
"max": "9223372036854775807",
"overflow": "9223372036854775808",
}
var testCases = map[string]struct {
target string
expect int
expectErr bool
}{
"foo": {target: "foo", expect: 1},
"bar": {target: "bar", expectErr: true},
"baz": {target: "baz", expectErr: true},
"negative": {target: "negative", expect: -1},
"max": {target: "max", expect: 9223372036854775807},
"overflow": {target: "overflow", expectErr: true},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
result, err := intFromLabel(labels, tc.target)
if tc.expectErr {
assert.Error(tt, err, name)
} else {
assert.NoError(tt, err, name)
assert.Equal(tt, tc.expect, result, name)
}
})
}
}
37 changes: 37 additions & 0 deletions internal/k8s/spin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package k8s

import (
"context"
"strings"
"testing"
"time"

"github.com/alecthomas/assert/v2"
)

func TestSpinAfter(t *testing.T) {
wait := 500 * time.Millisecond
var testCases = map[string]struct {
connectTime time.Duration
expectSpinner bool
}{
"spinner": {connectTime: 600 * time.Millisecond, expectSpinner: true},
"no spinner": {connectTime: 400 * time.Millisecond, expectSpinner: false},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
var buf strings.Builder
// start the spinner with a given connect time
ctx, cancel := context.WithTimeout(context.Background(), tc.connectTime)
wg := spinAfter(ctx, &buf, wait)
wg.Wait()
cancel()
// check if the builder has spinner animations
if tc.expectSpinner {
assert.NotZero(tt, buf.Len(), name)
} else {
assert.Zero(tt, buf.Len(), name)
}
})
}
}
39 changes: 39 additions & 0 deletions internal/k8s/termsizequeue_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package k8s

import (
"context"
"testing"

"github.com/alecthomas/assert/v2"
"github.com/gliderlabs/ssh"
"k8s.io/client-go/tools/remotecommand"
)

func TestTermSizeQueue(t *testing.T) {
var testCases = map[string]struct {
input ssh.Window
expect remotecommand.TerminalSize
}{
"term size change": {
input: ssh.Window{
Width: 100,
Height: 200,
},
expect: remotecommand.TerminalSize{
Width: 100,
Height: 200,
},
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
in := make(chan ssh.Window, 1)
tsq := newTermSizeQueue(ctx, in)
in <- tc.input
output := tsq.Next()
assert.Equal(tt, tc.expect, *output, name)
})
}
}
27 changes: 27 additions & 0 deletions internal/k8s/validate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package k8s_test

import (
"testing"

"github.com/alecthomas/assert/v2"
"github.com/uselagoon/ssh-portal/internal/k8s"
)

func TestValidateLabelValues(t *testing.T) {
var testCases = map[string]struct {
input string
expectError bool
}{
"valid": {input: "foo", expectError: false},
"invalid": {input: "naïve", expectError: true},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
if tc.expectError {
assert.Error(tt, k8s.ValidateLabelValue(tc.input), name)
} else {
assert.NoError(tt, k8s.ValidateLabelValue(tc.input), name)
}
})
}
}
9 changes: 7 additions & 2 deletions internal/rbac/usercansshtoenvironment.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,13 @@ var defaultEnvTypeRoleCanSSH = map[lagoon.EnvironmentType][]lagoon.UserRole{
// UserCanSSHToEnvironment returns true if the given environment can be
// connected to via SSH by the user with the given realm roles and user groups,
// and false otherwise.
func (p *Permission) UserCanSSHToEnvironment(ctx context.Context, env *lagoondb.Environment,
realmRoles, userGroups []string, groupProjectIDs map[string][]int) bool {
func (p *Permission) UserCanSSHToEnvironment(
ctx context.Context,
env *lagoondb.Environment,
realmRoles,
userGroups []string,
groupProjectIDs map[string][]int,
) bool {
// set up tracing
_, span := otel.Tracer(pkgName).Start(ctx, "UserCanSSHToEnvironment")
defer span.End()
Expand Down
24 changes: 24 additions & 0 deletions internal/sshserver/serve_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package sshserver

import (
"slices"
"testing"

"github.com/alecthomas/assert/v2"
)

func TestDisableSHA1Kex(t *testing.T) {
var testCases = map[string]struct {
input string
expect bool
}{
"no sha1": {input: "diffie-hellman-group14-sha1", expect: false},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
conf := disableSHA1Kex(nil)
assert.Equal(tt, tc.expect,
slices.Contains(conf.Config.KeyExchanges, tc.input), name)
})
}
}
65 changes: 45 additions & 20 deletions internal/sshserver/sessionhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,42 @@ var (
})
)

// authCtxValues extracts the context values set by the authhandler.
func authCtxValues(ctx ssh.Context) (int, string, int, string, string, error) {
var ok bool
var eid, pid int
var ename, pname, fingerprint string
eid, ok = ctx.Value(environmentIDKey).(int)
if !ok {
return eid, ename, pid, pname, fingerprint,
fmt.Errorf("couldn't extract environment ID from session context")
}
ename, ok = ctx.Value(environmentNameKey).(string)
if !ok {
return eid, ename, pid, pname, fingerprint,
fmt.Errorf("couldn't extract environment name from session context")
}
pid, ok = ctx.Value(projectIDKey).(int)
if !ok {
return eid, ename, pid, pname, fingerprint,
fmt.Errorf("couldn't extract project ID from session context")
}
pname, ok = ctx.Value(projectNameKey).(string)
if !ok {
return eid, ename, pid, pname, fingerprint,
fmt.Errorf("couldn't extract project name from session context")
}
fingerprint, ok = ctx.Value(sshFingerprint).(string)
if !ok {
return eid, ename, pid, pname, fingerprint,
fmt.Errorf("couldn't extract SSH key fingerprint from session context")
}
return eid, ename, pid, pname, fingerprint, nil
}

// getSSHIntent analyses the SFTP flag and the raw command strings to determine
// if the command should be wrapped.
// if the command should be wrapped, and returns the given cmd wrapped
// appropriately.
func getSSHIntent(sftp bool, cmd []string) []string {
// if this is an sftp session we ignore any commands
if sftp {
Expand Down Expand Up @@ -104,25 +138,16 @@ func sessionHandler(log *slog.Logger, c K8SAPIService,
return
}
// extract info passed through the context by the authhandler
eid, ok := ctx.Value(environmentIDKey).(int)
if !ok {
log.Warn("couldn't extract environment ID from session context")
}
ename, ok := ctx.Value(environmentNameKey).(string)
if !ok {
log.Warn("couldn't extract environment name from session context")
}
pid, ok := ctx.Value(projectIDKey).(int)
if !ok {
log.Warn("couldn't extract project ID from session context")
}
pname, ok := ctx.Value(projectNameKey).(string)
if !ok {
log.Warn("couldn't extract project name from session context")
}
fingerprint, ok := ctx.Value(sshFingerprint).(string)
if !ok {
log.Warn("couldn't extract SSH key fingerprint from session context")
eid, ename, pid, pname, fingerprint, err := authCtxValues(ctx)
if err != nil {
log.Error("couldn't extract auth values from context",
slog.Any("error", err))
_, err = fmt.Fprintf(s.Stderr(), "error executing command. SID: %s\r\n",
ctx.SessionID())
if err != nil {
log.Debug("couldn't write to session stream", slog.Any("error", err))
}
return
}
if len(logs) != 0 {
if !logAccessEnabled {
Expand Down