diff --git a/charts/aws-s3-csi-driver/templates/node.yaml b/charts/aws-s3-csi-driver/templates/node.yaml index f074c5a1..a189f25f 100644 --- a/charts/aws-s3-csi-driver/templates/node.yaml +++ b/charts/aws-s3-csi-driver/templates/node.yaml @@ -61,6 +61,20 @@ spec: valueFrom: fieldRef: fieldPath: spec.nodeName + {{- with .Values.awsAccessSecret }} + - name: AWS_ACCESS_KEY_ID + valueFrom: + secretKeyRef: + name: {{ .name }} + key: {{ .keyId }} + optional: true + - name: AWS_SECRET_ACCESS_KEY + valueFrom: + secretKeyRef: + name: {{ .name }} + key: {{ .accessKey }} + optional: true + {{- end }} volumeMounts: - name: kubelet-dir mountPath: /var/lib/kubelet diff --git a/charts/aws-s3-csi-driver/values.yaml b/charts/aws-s3-csi-driver/values.yaml index d2ceca46..3d1f62d4 100644 --- a/charts/aws-s3-csi-driver/values.yaml +++ b/charts/aws-s3-csi-driver/values.yaml @@ -53,3 +53,8 @@ nameOverride: "" fullnameOverride: "" imagePullSecrets: [] + +awsAccessSecret: + name: aws-secret + keyId: key_id + accessKey: access_key \ No newline at end of file diff --git a/deploy/kubernetes/base/node-daemonset.yaml b/deploy/kubernetes/base/node-daemonset.yaml index 2f03116d..af3cfa99 100644 --- a/deploy/kubernetes/base/node-daemonset.yaml +++ b/deploy/kubernetes/base/node-daemonset.yaml @@ -49,6 +49,18 @@ spec: env: - name: CSI_ENDPOINT value: unix:/csi/csi.sock + - name: AWS_ACCESS_KEY_ID + valueFrom: + secretKeyRef: + name: aws-secret + key: key_id + optional: true + - name: AWS_SECRET_ACCESS_KEY + valueFrom: + secretKeyRef: + name: aws-secret + key: access_key + optional: true volumeMounts: - name: kubelet-dir mountPath: /var/lib/kubelet diff --git a/pkg/driver/mount.go b/pkg/driver/mount.go index a0f30953..54d8a59f 100644 --- a/pkg/driver/mount.go +++ b/pkg/driver/mount.go @@ -29,6 +29,11 @@ import ( "k8s.io/mount-utils" ) +const ( + keyIdEnv = "AWS_ACCESS_KEY_ID" + accessKeyEnv = "AWS_SECRET_ACCESS_KEY" +) + // Mounter is an interface for mount operations type Mounter interface { mount.Interface @@ -86,8 +91,15 @@ func (m *S3Mounter) PathExists(path string) (bool, error) { func (m *S3Mounter) Mount(source string, target string, _ string, options []string) error { timeoutCtx, cancel := context.WithTimeout(m.ctx, 30*time.Second) defer cancel() + keyId := os.Getenv(keyIdEnv) + accessKey := os.Getenv(accessKeyEnv) + env := []string{} + if keyId != "" && accessKey != "" { + env = append(env, keyIdEnv+"="+keyId) + env = append(env, accessKeyEnv+"="+accessKey) + } output, err := m.runner.Run(timeoutCtx, "/usr/bin/mount-s3", m.mpVersion+"-"+uuid.New().String(), - append([]string{source, target}, options...)) + env, append([]string{source, target}, options...)) if err != nil { return fmt.Errorf("Mount failed: %w mount-s3 output: %s", err, output) } diff --git a/pkg/driver/systemd.go b/pkg/driver/systemd.go index 2c908cba..ec5e75f4 100644 --- a/pkg/driver/systemd.go +++ b/pkg/driver/systemd.go @@ -60,7 +60,7 @@ func NewSystemdRunner() SystemdRunner { } // Run a given command in a transient systemd service. Will wait for the service to become active -func (sr *SystemdRunner) Run(ctx context.Context, cmd string, serviceTag string, args []string) (string, error) { +func (sr *SystemdRunner) Run(ctx context.Context, cmd string, serviceTag string, env []string, args []string) (string, error) { systemdConn, err := sr.Connector.Connect(ctx) if err != nil { return "", fmt.Errorf("Failed to connect to systemd: %w", err) @@ -83,6 +83,9 @@ func (sr *SystemdRunner) Run(ctx context.Context, cmd string, serviceTag string, {Name: "TTYPath", Value: dbus.MakeVariant(fmt.Sprintf("/dev/pts/%d", ptsN))}, systemd.PropExecStart(append([]string{cmd}, args...), true), } + if len(env) > 0 { + props = append(props, systemd.Property{Name: "Environment", Value: dbus.MakeVariant(env)}) + } // Unit names must be unique in systemd, so include a tag serviceName := filepath.Base(cmd) + "-" + serviceTag + ".service" diff --git a/pkg/driver/systemd_test.go b/pkg/driver/systemd_test.go index 15011ef9..cee864e7 100644 --- a/pkg/driver/systemd_test.go +++ b/pkg/driver/systemd_test.go @@ -3,6 +3,7 @@ package driver_test import ( "context" "errors" + "fmt" "io" "strings" "testing" @@ -10,6 +11,7 @@ import ( driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver" mock_driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver/mocks" systemd "github.com/coreos/go-systemd/v22/dbus" + dbus "github.com/godbus/dbus/v5" "github.com/golang/mock/gomock" ) @@ -46,7 +48,7 @@ func TestSystemdRunFailedConnection(t *testing.T) { runner := &driver.SystemdRunner{ Connector: mockConnector, } - out, err := runner.Run(ctx, "", "", nil) + out, err := runner.Run(ctx, "", "", nil, nil) if err == nil { t.Fatalf("Expected error on connection failure") } @@ -69,7 +71,7 @@ func TestSystemdRunNewPtsFailure(t *testing.T) { Connector: mockConnector, Pts: mockPts, } - out, err := runner.Run(ctx, "", "", nil) + out, err := runner.Run(ctx, "", "", nil, nil) if err == nil { t.Fatalf("Expected error on connection failure") } @@ -95,7 +97,7 @@ func TestSystemdStartUnitFailure(t *testing.T) { Connector: mockConnector, Pts: mockPts, } - out, err := runner.Run(ctx, "", "", nil) + out, err := runner.Run(ctx, "", "", nil, nil) if err == nil { t.Fatalf("Expected error on connection failure") } @@ -123,7 +125,7 @@ func TestSystemdRunCanceledContext(t *testing.T) { Connector: mockConnector, Pts: mockPts, } - out, err := runner.Run(ctx, "", "", nil) + out, err := runner.Run(ctx, "", "", nil, nil) if err == nil { t.Fatalf("Expected error on connection failure") } @@ -140,11 +142,23 @@ func TestSystemdRunSuccess(t *testing.T) { mockConnection.EXPECT().Close() mockConnector.EXPECT().Connect(gomock.Any()).Return(mockConnection, nil) testOutput := "testoutputdata" - mockPts.EXPECT().NewPts().Return(io.NopCloser(strings.NewReader(testOutput)), 0, nil) + ptsN := 5 + mockPts.EXPECT().NewPts().Return(io.NopCloser(strings.NewReader(testOutput)), 5, nil) + env := []string{"TEST=TEST"} + args := []string{"--test-arg1", "--test-arg2"} ctx := context.Background() + expectedProps := []systemd.Property{ + systemd.PropDescription("Mountpoint for S3 CSI driver FUSE daemon"), + systemd.PropType("forking"), + {Name: "StandardOutput", Value: dbus.MakeVariant("tty")}, + {Name: "StandardError", Value: dbus.MakeVariant("tty")}, + {Name: "TTYPath", Value: dbus.MakeVariant(fmt.Sprintf("/dev/pts/%d", ptsN))}, + systemd.PropExecStart(append([]string{testExe}, args...), true), + {Name: "Environment", Value: dbus.MakeVariant(env)}, + } mockConnection.EXPECT().StartTransientUnitContext( - gomock.Eq(ctx), gomock.Any(), gomock.Eq("fail"), gomock.Any(), gomock.Any()). + gomock.Eq(ctx), gomock.Any(), gomock.Eq("fail"), gomock.Eq(expectedProps), gomock.Any()). Do(func(_ context.Context, name string, _ string, _ []systemd.Property, ch chan<- string) { go func() { ch <- "done" }() }).Return(0, nil) @@ -161,7 +175,7 @@ func TestSystemdRunSuccess(t *testing.T) { Connector: mockConnector, Pts: mockPts, } - out, err := runner.Run(ctx, testExe, testTag, nil) + out, err := runner.Run(ctx, testExe, testTag, env, args) if err != nil { t.Fatalf("Unexpected error: %v", err) }