diff --git a/internal/k8s/exec_test.go b/internal/k8s/exec_test.go index e92daea..6976bcf 100644 --- a/internal/k8s/exec_test.go +++ b/internal/k8s/exec_test.go @@ -153,7 +153,7 @@ func TestIdledDeployLabels(t *testing.T) { t.Run(name, func(tt *testing.T) { // create fake Kubernetes client with test deploys c := &Client{ - clientset: fake.NewSimpleClientset(tc.deploys), + clientset: fake.NewClientset(tc.deploys), } deploys, err := c.idledDeploys(context.Background(), testNS) assert.NoError(tt, err, name) diff --git a/internal/k8s/logs.go b/internal/k8s/logs.go index 641ce76..a79daae 100644 --- a/internal/k8s/logs.go +++ b/internal/k8s/logs.go @@ -230,8 +230,7 @@ func (c *Client) Logs( } defer c.logSem.Release(1) // Wrap the context so we can cancel subroutines of this function on error. - childCtx, cancel := - context.WithTimeoutCause(ctx, c.logTimeLimit, ErrLogTimeLimit) + childCtx, cancel := context.WithTimeout(ctx, c.logTimeLimit) defer cancel() // Generate a requestID value to uniquely distinguish between multiple calls // to this function. This requestID is used in readLogs() to distinguish @@ -278,7 +277,7 @@ func (c *Client) Logs( return fmt.Errorf("couldn't construct new pod informer: %v", err) } podInformer.Run(childCtx.Done()) - if errors.Is(childCtx.Err(), ErrLogTimeLimit) { + if errors.Is(childCtx.Err(), context.DeadlineExceeded) { return ErrLogTimeLimit } return nil diff --git a/internal/k8s/logs_test.go b/internal/k8s/logs_test.go index 010b9eb..0a2fc26 100644 --- a/internal/k8s/logs_test.go +++ b/internal/k8s/logs_test.go @@ -1,6 +1,7 @@ package k8s import ( + "bytes" "context" "io" "strings" @@ -8,6 +9,12 @@ import ( "time" "github.com/alecthomas/assert/v2" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" ) func TestLinewiseCopy(t *testing.T) { @@ -44,3 +51,101 @@ func TestLinewiseCopy(t *testing.T) { }) } } + +func TestLogs(t *testing.T) { + testNS := "testns" + testDeploy := "foo" + testPod := "bar" + deploys := &appsv1.DeploymentList{ + Items: []appsv1.Deployment{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: testDeploy, + Namespace: testNS, + Labels: map[string]string{ + "idling.lagoon.sh/watch": "true", + }, + }, + Spec: appsv1.DeploymentSpec{ + Selector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "app.kubernetes.io/name": "foo-app", + }, + }, + }, + }, + }, + } + pods := &corev1.PodList{ + Items: []corev1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "foo-123xyz", + Namespace: testNS, + Labels: map[string]string{ + "app.kubernetes.io/name": "foo-app", + }, + }, + Status: corev1.PodStatus{ + ContainerStatuses: []corev1.ContainerStatus{ + { + Name: testPod, + }, + }, + }, + }, + }, + } + var testCases = map[string]struct { + follow bool + sessionCount uint + expectError bool + expectedError error + }{ + "no follow": { + sessionCount: 1, + }, + "no follow two sessions": { + sessionCount: 2, + }, + "no follow session count limit exceeded": { + sessionCount: 3, + expectError: true, + expectedError: ErrConcurrentLogLimit, + }, + "follow session timeout": { + follow: true, + sessionCount: 1, + expectError: true, + expectedError: ErrLogTimeLimit, + }, + } + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + // create fake Kubernetes client with test deploys + c := &Client{ + clientset: fake.NewClientset(deploys, pods), + logSem: semaphore.NewWeighted(int64(2)), + logTimeLimit: time.Second, + } + // execute test + var buf bytes.Buffer + var eg errgroup.Group + ctx := context.Background() + for range tc.sessionCount { + eg.Go(func() error { + return c.Logs(ctx, testNS, testDeploy, testPod, tc.follow, 10, &buf) + }) + } + // check results + err := eg.Wait() + if tc.expectError { + assert.Error(tt, err, name) + assert.Equal(tt, err, tc.expectedError, name) + } else { + assert.NoError(tt, err, name) + tt.Log(buf.String()) + } + }) + } +}