From ef1c7051c7d5aaab112d1d1ca5b35230e0ec366e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Wed, 22 Jan 2025 17:06:57 +0000 Subject: [PATCH] Create `mountpoint.Args` for parsing and accessing Mountpoint args (#349) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, accessing and manipulating of Mountpoint arguments was spread over to the codebase. This new mountpoint package contains all the logic related to Mountpoint arguments. Splitted out of https://github.com/awslabs/mountpoint-s3-csi-driver/pull/328. --- By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Signed-off-by: Burak Varlı --- .../csimounter/csimounter.go | 13 +- .../node/mounter/credential_provider.go | 13 +- .../node/mounter/credential_provider_test.go | 35 +- pkg/driver/node/mounter/fake_mounter.go | 4 +- pkg/driver/node/mounter/mocks/mock_mount.go | 9 +- pkg/driver/node/mounter/mount_credentials.go | 1 - pkg/driver/node/mounter/mounter.go | 3 +- pkg/driver/node/mounter/systemd_mounter.go | 61 +-- .../node/mounter/systemd_mounter_test.go | 38 +- pkg/driver/node/node.go | 24 +- pkg/driver/node/node_test.go | 9 +- pkg/mountpoint/args.go | 145 ++++++ pkg/mountpoint/args_test.go | 416 ++++++++++++++++++ 13 files changed, 625 insertions(+), 146 deletions(-) create mode 100644 pkg/mountpoint/args.go create mode 100644 pkg/mountpoint/args_test.go diff --git a/cmd/aws-s3-csi-mounter/csimounter/csimounter.go b/cmd/aws-s3-csi-mounter/csimounter/csimounter.go index 6b5ba551..84d5afd3 100644 --- a/cmd/aws-s3-csi-mounter/csimounter/csimounter.go +++ b/cmd/aws-s3-csi-mounter/csimounter/csimounter.go @@ -7,10 +7,10 @@ import ( "io/fs" "os" "os/exec" - "slices" "k8s.io/klog/v2" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mountoptions" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" ) @@ -52,20 +52,17 @@ func Run(options Options) (int, error) { return 0, fmt.Errorf("passed file descriptor %d is invalid", mountOptions.Fd) } - args := mountOptions.Args + mountpointArgs := mountpoint.ParseArgs(mountOptions.Args) // By default Mountpoint runs in a detached mode. Here we want to monitor it by relaying its output, // and also we want to wait until it terminates. We're passing `--foreground` to achieve this. - const foreground, foregroundShort = "--foreground", "-f" - if !(slices.Contains(args, foreground) || slices.Contains(args, foregroundShort)) { - args = append(args, foreground) - } + mountpointArgs.Set(mountpoint.ArgForeground, mountpoint.ArgNoValue) - args = append([]string{ + args := append([]string{ mountOptions.BucketName, // We pass FUSE fd using `ExtraFiles`, and each entry becomes as file descriptor 3+i. "/dev/fd/3", - }, args...) + }, mountpointArgs.SortedList()...) cmd := exec.Command(options.MountpointPath, args...) cmd.ExtraFiles = []*os.File{fuseDev} diff --git a/pkg/driver/node/mounter/credential_provider.go b/pkg/driver/node/mounter/credential_provider.go index 12cc4a82..4b90d0f3 100644 --- a/pkg/driver/node/mounter/credential_provider.go +++ b/pkg/driver/node/mounter/credential_provider.go @@ -22,6 +22,7 @@ import ( k8sstrings "k8s.io/utils/strings" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" ) const hostPluginDirEnv = "HOST_PLUGIN_DIR" @@ -84,7 +85,7 @@ func (c *CredentialProvider) CleanupToken(volumeID string, podID string) error { // Provide provides mount credentials for given volume and volume context. // Depending on the configuration, it either returns driver-level or pod-level credentials. -func (c *CredentialProvider) Provide(ctx context.Context, volumeID string, volumeCtx map[string]string, mountpointArgs []string) (*MountCredentials, error) { +func (c *CredentialProvider) Provide(ctx context.Context, volumeID string, volumeCtx map[string]string, args mountpoint.Args) (*MountCredentials, error) { if volumeCtx == nil { return nil, status.Error(codes.InvalidArgument, "Missing volume context") } @@ -92,7 +93,7 @@ func (c *CredentialProvider) Provide(ctx context.Context, volumeID string, volum authenticationSource := volumeCtx[volumecontext.AuthenticationSource] switch authenticationSource { case AuthenticationSourcePod: - return c.provideFromPod(ctx, volumeID, volumeCtx, mountpointArgs) + return c.provideFromPod(ctx, volumeID, volumeCtx, args) case AuthenticationSourceUnspecified, AuthenticationSourceDriver: return c.provideFromDriver() default: @@ -119,7 +120,7 @@ func (c *CredentialProvider) provideFromDriver() (*MountCredentials, error) { }, nil } -func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string, volumeCtx map[string]string, mountpointArgs []string) (*MountCredentials, error) { +func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string, volumeCtx map[string]string, args mountpoint.Args) (*MountCredentials, error) { klog.V(4).Infof("NodePublishVolume: Using pod identity") tokensJson := volumeCtx[volumecontext.CSIServiceAccountTokens] @@ -144,7 +145,7 @@ func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string return nil, err } - region, err := c.stsRegion(volumeCtx, mountpointArgs) + region, err := c.stsRegion(volumeCtx, args) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "Failed to detect STS AWS Region, please explicitly set the AWS Region, see "+stsConfigDocsPage) } @@ -245,14 +246,14 @@ func (c *CredentialProvider) findPodServiceAccountRole(ctx context.Context, volu // 4. Calling IMDS to detect region // // It returns an error if all of them fails. -func (c *CredentialProvider) stsRegion(volumeCtx map[string]string, mountpointArgs []string) (string, error) { +func (c *CredentialProvider) stsRegion(volumeCtx map[string]string, args mountpoint.Args) (string, error) { region := volumeCtx[volumecontext.STSRegion] if region != "" { klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from volume context", region) return region, nil } - if region, ok := ExtractMountpointArgument(mountpointArgs, mountpointArgRegion); ok { + if region, ok := args.Value(mountpoint.ArgRegion); ok { klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from S3 bucket region", region) return region, nil } diff --git a/pkg/driver/node/mounter/credential_provider_test.go b/pkg/driver/node/mounter/credential_provider_test.go index 26f5a7c9..ac8c23e2 100644 --- a/pkg/driver/node/mounter/credential_provider_test.go +++ b/pkg/driver/node/mounter/credential_provider_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -42,7 +43,7 @@ func TestProvidingDriverLevelCredentials(t *testing.T) { } { provider := mounter.NewCredentialProvider(nil, "", mounter.RegionFromIMDSOnce) - credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, nil) + credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) assertEquals(t, credentials.AccessKeyID, "test-access-key") @@ -58,7 +59,7 @@ func TestProvidingDriverLevelCredentials(t *testing.T) { func TestProvidingDriverLevelCredentialsWithEmptyEnv(t *testing.T) { provider := mounter.NewCredentialProvider(nil, "", mounter.RegionFromIMDSOnce) - credentials, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{"authenticationSource": "driver"}, nil) + credentials, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{"authenticationSource": "driver"}, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) assertEquals(t, credentials.AccessKeyID, "") @@ -93,7 +94,7 @@ func TestProvidingPodLevelCredentials(t *testing.T) { Token: "test-service-account-token", }, }), - }, nil) + }, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) // Should disable env variable provider @@ -216,7 +217,7 @@ func TestProvidingPodLevelCredentialsWithMissingInformation(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, nil) + credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, mountpoint.ParseArgs(nil)) assertEquals(t, nil, credentials) if err == nil { t.Error("it should fail with missing information") @@ -252,7 +253,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { return "", errors.New("unknown region") }) - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) assertEquals(t, nil, credentials) if err == nil { t.Error("it should fail if there is not any region information") @@ -268,7 +269,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { return "us-east-1", nil }) - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "us-east-1") assertEquals(t, credentials.DefaultRegion, "us-east-1") @@ -286,7 +287,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { t.Setenv("AWS_REGION", "eu-west-1") - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "eu-west-1") assertEquals(t, credentials.DefaultRegion, "eu-west-1") @@ -304,7 +305,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { t.Setenv("AWS_DEFAULT_REGION", "eu-west-1") - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "eu-west-1") assertEquals(t, credentials.DefaultRegion, "eu-west-1") @@ -323,7 +324,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { t.Setenv("AWS_REGION", "eu-west-1") t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "eu-west-1") assertEquals(t, credentials.DefaultRegion, "eu-north-1") @@ -341,7 +342,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { t.Setenv("AWS_REGION", "eu-west-1") - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"}) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"})) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "us-west-1") assertEquals(t, credentials.DefaultRegion, "us-west-1") @@ -359,7 +360,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { t.Setenv("AWS_REGION", "eu-west-1") - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--read-only"}) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--read-only"})) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "eu-west-1") assertEquals(t, credentials.DefaultRegion, "eu-west-1") @@ -378,7 +379,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { t.Setenv("AWS_REGION", "eu-west-1") t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"}) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"})) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "us-west-1") assertEquals(t, credentials.DefaultRegion, "eu-north-1") @@ -398,7 +399,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { volumeContext["stsRegion"] = "ap-south-1" - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"}) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"})) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "ap-south-1") assertEquals(t, credentials.DefaultRegion, "ap-south-1") @@ -419,7 +420,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { volumeContext["stsRegion"] = "ap-south-1" - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"}) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"})) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "ap-south-1") assertEquals(t, credentials.DefaultRegion, "eu-north-1") @@ -457,7 +458,7 @@ func TestProvidingPodLevelCredentialsForDifferentPodsWithDifferentRoles(t *testi Token: "test-service-account-token-1", }, }), - }, nil) + }, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) credentialsPodTwo, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{ @@ -470,7 +471,7 @@ func TestProvidingPodLevelCredentialsForDifferentPodsWithDifferentRoles(t *testi Token: "test-service-account-token-2", }, }), - }, nil) + }, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) // PodOne @@ -526,7 +527,7 @@ func TestProvidingPodLevelCredentialsWithSlashInVolumeID(t *testing.T) { Token: "test-service-account-token", }, }), - }, nil) + }, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) assertEquals(t, credentials.AccessKeyID, "") diff --git a/pkg/driver/node/mounter/fake_mounter.go b/pkg/driver/node/mounter/fake_mounter.go index a33a3ec0..c97c40a0 100644 --- a/pkg/driver/node/mounter/fake_mounter.go +++ b/pkg/driver/node/mounter/fake_mounter.go @@ -1,9 +1,11 @@ package mounter +import "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" + type FakeMounter struct{} func (m *FakeMounter) Mount(bucketName string, target string, - credentials *MountCredentials, options []string) error { + credentials *MountCredentials, args mountpoint.Args) error { return nil } diff --git a/pkg/driver/node/mounter/mocks/mock_mount.go b/pkg/driver/node/mounter/mocks/mock_mount.go index 6d6053cc..fdfe8880 100644 --- a/pkg/driver/node/mounter/mocks/mock_mount.go +++ b/pkg/driver/node/mounter/mocks/mock_mount.go @@ -9,6 +9,7 @@ import ( reflect "reflect" mounter "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" + mountpoint "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" system "github.com/awslabs/aws-s3-csi-driver/pkg/system" gomock "github.com/golang/mock/gomock" ) @@ -105,17 +106,17 @@ func (mr *MockMounterMockRecorder) IsMountPoint(target interface{}) *gomock.Call } // Mount mocks base method. -func (m *MockMounter) Mount(bucketName, target string, credentials *mounter.MountCredentials, options []string) error { +func (m *MockMounter) Mount(bucketName, target string, credentials *mounter.MountCredentials, args mountpoint.Args) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Mount", bucketName, target, credentials, options) + ret := m.ctrl.Call(m, "Mount", bucketName, target, credentials, args) ret0, _ := ret[0].(error) return ret0 } // Mount indicates an expected call of Mount. -func (mr *MockMounterMockRecorder) Mount(bucketName, target, credentials, options interface{}) *gomock.Call { +func (mr *MockMounterMockRecorder) Mount(bucketName, target, credentials, args interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), bucketName, target, credentials, options) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), bucketName, target, credentials, args) } // Unmount mocks base method. diff --git a/pkg/driver/node/mounter/mount_credentials.go b/pkg/driver/node/mounter/mount_credentials.go index 92cdf010..b9a87e52 100644 --- a/pkg/driver/node/mounter/mount_credentials.go +++ b/pkg/driver/node/mounter/mount_credentials.go @@ -20,7 +20,6 @@ const ( MountpointCacheKey = "UNSTABLE_MOUNTPOINT_CACHE_KEY" defaultMountS3Path = "/usr/bin/mount-s3" userAgentPrefix = "--user-agent-prefix" - awsMaxAttemptsOption = "--aws-max-attempts" ) type MountCredentials struct { diff --git a/pkg/driver/node/mounter/mounter.go b/pkg/driver/node/mounter/mounter.go index ece7b7c2..4ea5d82c 100644 --- a/pkg/driver/node/mounter/mounter.go +++ b/pkg/driver/node/mounter/mounter.go @@ -5,6 +5,7 @@ import ( "context" "os" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" "github.com/awslabs/aws-s3-csi-driver/pkg/system" ) @@ -15,7 +16,7 @@ type ServiceRunner interface { // Mounter is an interface for mount operations type Mounter interface { - Mount(bucketName string, target string, credentials *MountCredentials, options []string) error + Mount(bucketName string, target string, credentials *MountCredentials, args mountpoint.Args) error Unmount(target string) error IsMountPoint(target string) (bool, error) } diff --git a/pkg/driver/node/mounter/systemd_mounter.go b/pkg/driver/node/mounter/systemd_mounter.go index e6c0db66..d73978c7 100644 --- a/pkg/driver/node/mounter/systemd_mounter.go +++ b/pkg/driver/node/mounter/systemd_mounter.go @@ -5,10 +5,10 @@ import ( "fmt" "os" "path/filepath" - "strings" "time" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/awsprofile" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" "github.com/awslabs/aws-s3-csi-driver/pkg/system" "github.com/google/uuid" "k8s.io/klog/v2" @@ -79,7 +79,7 @@ func (m *SystemdMounter) IsMountPoint(target string) (bool, error) { // // This method will create the target path if it does not exist and if there is an existing corrupt // mount, it will attempt an unmount before attempting the mount. -func (m *SystemdMounter) Mount(bucketName string, target string, credentials *MountCredentials, options []string) error { +func (m *SystemdMounter) Mount(bucketName string, target string, credentials *MountCredentials, args mountpoint.Args) error { if bucketName == "" { return fmt.Errorf("bucket name is empty") } @@ -147,14 +147,14 @@ func (m *SystemdMounter) Mount(bucketName string, target string, credentials *Mo env = credentials.Env(awsProfile) } - options, env = moveOptionToEnvironmentVariables(awsMaxAttemptsOption, awsMaxAttemptsEnv, options, env) - options = addUserAgentToOptions(options, UserAgent(authenticationSource, m.kubernetesVersion)) + args, env = moveArgumentsToEnv(args, env) + args.Set(mountpoint.ArgUserAgentPrefix, UserAgent(authenticationSource, m.kubernetesVersion)) output, err := m.Runner.StartService(timeoutCtx, &system.ExecConfig{ Name: "mount-s3-" + m.MpVersion + "-" + uuid.New().String() + ".service", Description: "Mountpoint for Amazon S3 CSI driver FUSE daemon", ExecPath: m.MountS3Path, - Args: append(options, bucketName, target), + Args: append(args.SortedList(), bucketName, target), Env: env, }) @@ -168,36 +168,11 @@ func (m *SystemdMounter) Mount(bucketName string, target string, credentials *Mo return nil } -// Moves a parameter optionName from the options list to MP's environment variable list. We need this as options are -// passed to the driver in a single field, but MP sometimes only supports config from environment variables. -// Returns an updated options and environment. -func moveOptionToEnvironmentVariables(optionName string, envName string, options []string, env []string) ([]string, []string) { - optionIdx := -1 - for i, o := range options { - if strings.HasPrefix(o, optionName) { - optionIdx = i - break - } - } - if optionIdx != -1 { - // We can do replace here as we've just verified it has the right prefix - env = append(env, strings.Replace(options[optionIdx], optionName, envName, 1)) - options = append(options[:optionIdx], options[optionIdx+1:]...) - } - return options, env -} - -// method to add the user agent prefix to the Mountpoint headers -// https://github.com/awslabs/mountpoint-s3/pull/548 -func addUserAgentToOptions(options []string, userAgent string) []string { - // first remove it from the options in case it's in there - for i := len(options) - 1; i >= 0; i-- { - if strings.Contains(options[i], userAgentPrefix) { - options = append(options[:i], options[i+1:]...) - } +func moveArgumentsToEnv(args mountpoint.Args, env []string) (mountpoint.Args, []string) { + if maxAttempts, ok := args.Remove(mountpoint.ArgAWSMaxAttempts); ok { + env = append(env, fmt.Sprintf("%s=%s", awsMaxAttemptsEnv, maxAttempts)) } - // add the hard coded S3 CSI driver user agent string - return append(options, userAgentPrefix+"="+userAgent) + return args, env } func (m *SystemdMounter) Unmount(target string) error { @@ -224,21 +199,3 @@ func (m *SystemdMounter) Unmount(target string) error { } return nil } - -const ( - mountpointArgRegion = "region" - mountpointArgCache = "cache" -) - -// ExtractMountpointArgument extracts value of a given argument from `mountpointArgs`. -func ExtractMountpointArgument(mountpointArgs []string, argument string) (string, bool) { - // `mountpointArgs` normalized to `--arg=val` in `S3NodeServer.NodePublishVolume`. - prefix := fmt.Sprintf("--%s=", argument) - for _, arg := range mountpointArgs { - if strings.HasPrefix(arg, prefix) { - val := strings.SplitN(arg, "=", 2)[1] - return val, true - } - } - return "", false -} diff --git a/pkg/driver/node/mounter/systemd_mounter_test.go b/pkg/driver/node/mounter/systemd_mounter_test.go index 86c89689..6b3fffe0 100644 --- a/pkg/driver/node/mounter/systemd_mounter_test.go +++ b/pkg/driver/node/mounter/systemd_mounter_test.go @@ -13,6 +13,7 @@ import ( "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/awsprofile" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" mock_driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter/mocks" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" "github.com/awslabs/aws-s3-csi-driver/pkg/system" "github.com/golang/mock/gomock" "k8s.io/mount-utils" @@ -156,7 +157,7 @@ func TestS3MounterMount(t *testing.T) { testCase.before(t, env) } err := env.mounter.Mount(testCase.bucketName, testCase.targetPath, - testCase.credentials, testCase.options) + testCase.credentials, mountpoint.ParseArgs(testCase.options)) env.mockCtl.Finish() if err != nil && !testCase.expectedErr { t.Fatal(err) @@ -278,41 +279,6 @@ func TestProvidingEnvVariablesForMountpointProcess(t *testing.T) { } } -func TestExtractMountpointArgument(t *testing.T) { - for name, test := range map[string]struct { - input []string - argument string - expectedToFound bool - expectedValue string - }{ - "Extract Existing Argument": { - input: []string{ - "--region=us-east-1", - }, - argument: "region", - expectedToFound: true, - expectedValue: "us-east-1", - }, - "Extract Non Existing Argument": { - input: []string{ - "--bucket=test", - }, - argument: "region", - expectedToFound: false, - }, - "Extract Non Existing Argument With Empty Input": { - argument: "region", - expectedToFound: false, - }, - } { - t.Run(name, func(t *testing.T) { - val, found := mounter.ExtractMountpointArgument(test.input, test.argument) - assertEquals(t, test.expectedToFound, found) - assertEquals(t, test.expectedValue, val) - }) - } -} - func TestIsMountPoint(t *testing.T) { testDir := t.TempDir() mountpointS3MountPath := filepath.Join(testDir, "/var/lib/kubelet/pods/46efe8aa-75d9-4b12-8fdd-0ce0c2cabd99/volumes/kubernetes.io~csi/s3-mp-csi-pv/mount") diff --git a/pkg/driver/node/node.go b/pkg/driver/node/node.go index b26670e9..9ee0a01d 100644 --- a/pkg/driver/node/node.go +++ b/pkg/driver/node/node.go @@ -32,6 +32,7 @@ import ( "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/targetpath" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" ) const ( @@ -121,34 +122,25 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl mountpointArgs := []string{} if req.GetReadonly() || volCap.GetAccessMode().GetMode() == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY { - mountpointArgs = append(mountpointArgs, "--read-only") + mountpointArgs = append(mountpointArgs, mountpoint.ArgReadOnly) } - // get the mount(point) options (yaml list) if capMount := volCap.GetMount(); capMount != nil { mountFlags := capMount.GetMountFlags() - for i := range mountFlags { - // trim left and right spaces - // trim spaces in between from multiple spaces to just one i.e. uid 1001 would turn into uid 1001 - // if there is a space between, replace it with an = sign - mountFlags[i] = strings.Replace(strings.Join(strings.Fields(strings.Trim(mountFlags[i], " ")), " "), " ", "=", -1) - // prepend -- if it's not already there - if !strings.HasPrefix(mountFlags[i], "-") { - mountFlags[i] = "--" + mountFlags[i] - } - } - mountpointArgs = compileMountOptions(mountpointArgs, mountFlags) + mountpointArgs = append(mountpointArgs, mountFlags...) } - credentials, err := ns.credentialProvider.Provide(ctx, req.VolumeId, req.VolumeContext, mountpointArgs) + args := mountpoint.ParseArgs(mountpointArgs) + + credentials, err := ns.credentialProvider.Provide(ctx, req.VolumeId, req.VolumeContext, args) if err != nil { klog.Errorf("NodePublishVolume: failed to provide credentials: %v", err) return nil, err } - klog.V(4).Infof("NodePublishVolume: mounting %s at %s with options %v", bucket, target, mountpointArgs) + klog.V(4).Infof("NodePublishVolume: mounting %s at %s with options %v", bucket, target, args.SortedList()) - if err := ns.Mounter.Mount(bucket, target, credentials, mountpointArgs); err != nil { + if err := ns.Mounter.Mount(bucket, target, credentials, args); err != nil { os.Remove(target) return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", bucket, target, err) } diff --git a/pkg/driver/node/node_test.go b/pkg/driver/node/node_test.go index 7db0178e..040f1635 100644 --- a/pkg/driver/node/node_test.go +++ b/pkg/driver/node/node_test.go @@ -11,6 +11,7 @@ import ( "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" mock_driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter/mocks" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" csi "github.com/container-storage-interface/spec/lib/go/csi" "github.com/golang/mock/gomock" @@ -100,7 +101,7 @@ func TestNodePublishVolume(t *testing.T) { VolumeContext: map[string]string{"bucketName": bucketName}, } - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq([]string{"--read-only"})) + nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq(mountpoint.ParseArgs([]string{"--read-only"}))) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -131,7 +132,7 @@ func TestNodePublishVolume(t *testing.T) { Readonly: true, } - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq([]string{"--bar", "--foo", "--read-only", "--test=123"})) + nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq(mountpoint.ParseArgs([]string{"--bar", "--foo", "--read-only", "--test=123"}))) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -164,7 +165,7 @@ func TestNodePublishVolume(t *testing.T) { nodeTestEnv.mockMounter.EXPECT().Mount( gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), - gomock.Eq([]string{"--read-only", "--test=123"})).Return(nil) + gomock.Eq(mountpoint.ParseArgs([]string{"--read-only", "--test=123"}))).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -341,7 +342,7 @@ var _ mounter.Mounter = &dummyMounter{} type dummyMounter struct { } -func (d *dummyMounter) Mount(bucketName string, target string, credentials *mounter.MountCredentials, options []string) error { +func (d *dummyMounter) Mount(bucketName string, target string, credentials *mounter.MountCredentials, args mountpoint.Args) error { return nil } func (d *dummyMounter) Unmount(target string) error { diff --git a/pkg/mountpoint/args.go b/pkg/mountpoint/args.go new file mode 100644 index 00000000..4825484b --- /dev/null +++ b/pkg/mountpoint/args.go @@ -0,0 +1,145 @@ +package mountpoint + +import ( + "fmt" + "slices" + "strings" + + "k8s.io/apimachinery/pkg/util/sets" +) + +const ( + ArgForeground = "--foreground" + ArgReadOnly = "--read-only" + ArgAllowOther = "--allow-other" + ArgAllowRoot = "--allow-root" + ArgRegion = "--region" + ArgCache = "--cache" + ArgUserAgentPrefix = "--user-agent-prefix" + ArgAWSMaxAttempts = "--aws-max-attempts" +) + +// An ArgKey represents the key of an argument. +type ArgKey = string + +// An ArgValue represents the value of an argument. +type ArgValue = string + +// A value to use in arguments without any value, i.e., an option. +const ArgNoValue = "" + +// An arg represents an argument to be passed to Mountpoint. +type arg struct { + key ArgKey + value ArgValue +} + +// String returns string representation of the argument to pass Mountpoint. +func (a *arg) String() string { + if a.value == ArgNoValue { + return a.key + } + return fmt.Sprintf("%s=%s", a.key, a.value) +} + +// An Args represents arguments to be passed to Mountpoint during mount. +type Args struct { + args sets.Set[arg] +} + +// ParseArgs parses given list of unnormalized and returns a normalized [Args]. +func ParseArgs(passedArgs []string) Args { + args := sets.New[arg]() + + for _, a := range passedArgs { + var key, value string + + parts := strings.SplitN(strings.Trim(a, " "), "=", 2) + if len(parts) == 2 { + // Ex: `--key=value` or `key=value` + key, value = parts[0], parts[1] + } else { + // Ex: `--key value` or `key value` + // Ex: `--key` or `key` + parts = strings.SplitN(strings.Trim(parts[0], " "), " ", 2) + if len(parts) == 1 { + // Ex: `--key` or `key` + key = parts[0] + value = ArgNoValue + } else { + // Ex: `--key value` or `key value` + key, value = parts[0], strings.Trim(parts[1], " ") + } + } + + // prepend -- if it's not already there + key = normalizeKey(key) + + // disallow options that don't make sense in CSI + switch key { + case "--foreground", "-f", "--help", "-h", "--version", "-v": + continue + } + + args.Insert(arg{key, value}) + } + + return Args{args} +} + +// Set sets or replaces value of given key. +func (a *Args) Set(key ArgKey, value ArgValue) { + key = normalizeKey(key) + a.Remove(key) + a.args.Insert(arg{key, value}) +} + +// Value extracts value of given key, it returns extracted value and whether the key was found. +func (a *Args) Value(key ArgKey) (ArgValue, bool) { + arg, exists := a.find(key) + return arg.value, exists +} + +// Has returns whether given key exists in [Args]. +func (a *Args) Has(key ArgKey) bool { + _, exists := a.find(key) + return exists +} + +// Remove removes given key, it returns the key's value and whether the key was found. +func (a *Args) Remove(key ArgKey) (ArgValue, bool) { + arg, exists := a.find(key) + if exists { + a.args.Delete(arg) + } + return arg.value, exists +} + +// SortedList returns ordered list of normalized arguments. +func (a *Args) SortedList() []string { + args := make([]string, 0, a.args.Len()) + for _, arg := range a.args.UnsortedList() { + args = append(args, arg.String()) + } + slices.Sort(args) + return args +} + +// find tries to find given key from [Args], and returns whole entry, and whether the key was found. +func (a *Args) find(key ArgKey) (arg, bool) { + key = normalizeKey(key) + for _, arg := range a.args.UnsortedList() { + if key == arg.key { + return arg, true + } + } + return arg{}, false +} + +// normalizeKey normalized given key to have a "--" prefix. +func normalizeKey(key ArgKey) ArgKey { + if !strings.HasPrefix(key, "-") { + return "--" + key + } + return key +} diff --git a/pkg/mountpoint/args_test.go b/pkg/mountpoint/args_test.go new file mode 100644 index 00000000..d22318d1 --- /dev/null +++ b/pkg/mountpoint/args_test.go @@ -0,0 +1,416 @@ +package mountpoint_test + +import ( + "testing" + + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" +) + +func TestParsingMountpointArgs(t *testing.T) { + testCases := []struct { + name string + input []string + want []string + }{ + { + name: "no prefix", + input: []string{ + "allow-delete", + "region us-west-2", + "aws-max-attempts 5", + }, + want: []string{ + "--allow-delete", + "--aws-max-attempts=5", + "--region=us-west-2", + }, + }, + { + name: "with prefix", + input: []string{ + "--cache /tmp/s3-cache", + "--max-cache-size 500", + "--metadata-ttl 3", + }, + want: []string{ + "--cache=/tmp/s3-cache", + "--max-cache-size=500", + "--metadata-ttl=3", + }, + }, + { + name: "with equals but no prefix", + input: []string{ + "allow-delete", + "region=us-west-2", + "sse=aws:kms", + "sse-kms-key-id=arn:aws:kms:us-west-2:012345678900:key/00000000-0000-0000-0000-000000000000", + }, + want: []string{ + "--allow-delete", + "--region=us-west-2", + "--sse-kms-key-id=arn:aws:kms:us-west-2:012345678900:key/00000000-0000-0000-0000-000000000000", + "--sse=aws:kms", + }, + }, + { + name: "with equals and prefix", + input: []string{ + "--allow-other", + "--uid=1000", + "--gid=2000", + }, + want: []string{ + "--allow-other", + "--gid=2000", + "--uid=1000", + }, + }, + { + name: "with multiple spaces", + input: []string{ + "--allow-other", + "--uid 1000", + "--gid 2000", + }, + want: []string{ + "--allow-other", + "--gid=2000", + "--uid=1000", + }, + }, + { + name: "with spaces before and after", + input: []string{ + "--allow-other ", + " --uid 1000", + " --gid 2000 ", + }, + want: []string{ + "--allow-other", + "--gid=2000", + "--uid=1000", + }, + }, + { + name: "with single dash prefix", + input: []string{ + "-d", + "-l logs/", + }, + want: []string{ + "-d", + "-l=logs/", + }, + }, + { + name: "mixed prefix and equal signs", + input: []string{ + "--allow-delete", + "read-only", + "--cache=/tmp/s3-cache", + "--region us-east-1", + "prefix some-s3-prefix/", + "-d", + "-l=logs/", + }, + want: []string{ + "--allow-delete", + "--cache=/tmp/s3-cache", + "--prefix=some-s3-prefix/", + "--read-only", + "--region=us-east-1", + "-d", + "-l=logs/", + }, + }, + { + name: "with duplicated options", + input: []string{ + "--allow-other", + "--read-only", + "read-only", + "--allow-other", + }, + want: []string{ + "--allow-other", + "--read-only", + }, + }, + { + name: "with unsupported options", + input: []string{ + "--allow-other", + "--read-only", + "--foreground", "-f", + "--help", "-h", + "--version", "-v", + }, + want: []string{ + "--allow-other", + "--read-only", + }, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + args := mountpoint.ParseArgs(testCase.input) + assert.Equals(t, testCase.want, args.SortedList()) + }) + } +} + +func TestInsertingArgsToMountpointArgs(t *testing.T) { + testCases := []struct { + name string + existingArgs []string + key string + value string + want []string + }{ + { + name: "new option", + existingArgs: []string{ + "allow-delete", + "region us-west-2", + "aws-max-attempts 5", + }, + key: mountpoint.ArgReadOnly, + want: []string{ + "--allow-delete", + "--aws-max-attempts=5", + "--read-only", + "--region=us-west-2", + }, + }, + { + name: "existing option", + existingArgs: []string{ + "allow-delete", + "read-only", + "region us-west-2", + "aws-max-attempts 5", + }, + key: mountpoint.ArgReadOnly, + want: []string{ + "--allow-delete", + "--aws-max-attempts=5", + "--read-only", + "--region=us-west-2", + }, + }, + { + name: "new arg", + existingArgs: []string{ + "allow-delete", + "aws-max-attempts 5", + }, + key: mountpoint.ArgRegion, + value: "us-west-2", + want: []string{ + "--allow-delete", + "--aws-max-attempts=5", + "--region=us-west-2", + }, + }, + { + name: "existing arg", + existingArgs: []string{ + "allow-delete", + "read-only", + "region us-west-2", + "aws-max-attempts 5", + }, + key: mountpoint.ArgRegion, + value: "us-east-1", + want: []string{ + "--allow-delete", + "--aws-max-attempts=5", + "--read-only", + "--region=us-east-1", + }, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + args := mountpoint.ParseArgs(testCase.existingArgs) + args.Set(testCase.key, testCase.value) + assert.Equals(t, testCase.want, args.SortedList()) + }) + } +} + +func TestExtractingAnArgumentsValueFromMountpointArgs(t *testing.T) { + testCases := []struct { + name string + args []string + argToExtract string + want string + exists bool + }{ + { + name: "existing argument", + args: []string{ + "cache /tmp/s3-cache", + "region us-west-2", + "aws-max-attempts 5", + }, + argToExtract: mountpoint.ArgCache, + want: "/tmp/s3-cache", + exists: true, + }, + { + name: "existing argument with equal", + args: []string{ + "cache /tmp/s3-cache", + "region=us-west-2", + "aws-max-attempts 5", + }, + argToExtract: mountpoint.ArgRegion, + want: "us-west-2", + exists: true, + }, + { + name: "existing argument queried without prefix", + args: []string{ + "cache /tmp/s3-cache", + "region=us-west-2", + "aws-max-attempts 5", + }, + argToExtract: "region", + want: "us-west-2", + exists: true, + }, + { + name: "existing argument queried without equals", + args: []string{ + "cache /tmp/s3-cache", + "region=us-west-2", + "aws-max-attempts 5", + }, + argToExtract: "--region", + want: "us-west-2", + exists: true, + }, + { + name: "non-existent argument", + args: []string{ + "cache /tmp/s3-cache", + "aws-max-attempts 5", + }, + argToExtract: mountpoint.ArgRegion, + want: "", + exists: false, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + args := mountpoint.ParseArgs(testCase.args) + gotValue, gotExists := args.Value(testCase.argToExtract) + assert.Equals(t, testCase.want, gotValue) + assert.Equals(t, testCase.exists, gotExists) + }) + } +} + +func TestRemovingAnArgumentFromMountpointArgs(t *testing.T) { + testCases := []struct { + name string + args []string + argToRemove string + want string + argsAfter []string + exists bool + }{ + { + name: "existing argument", + args: []string{ + "user-agent-prefix foo/bar", + "cache /tmp/s3-cache", + "region us-west-2", + "aws-max-attempts 5", + }, + argToRemove: mountpoint.ArgUserAgentPrefix, + want: "foo/bar", + exists: true, + argsAfter: []string{ + "--aws-max-attempts=5", + "--cache=/tmp/s3-cache", + "--region=us-west-2", + }, + }, + { + name: "existing argument without key prefix", + args: []string{ + "cache /tmp/s3-cache", + "region=us-west-2", + "aws-max-attempts=5", + }, + argToRemove: "aws-max-attempts", + want: "5", + exists: true, + argsAfter: []string{ + "--cache=/tmp/s3-cache", + "--region=us-west-2", + }, + }, + { + name: "non-existent argument", + args: []string{ + "cache /tmp/s3-cache", + "aws-max-attempts 5", + }, + argToRemove: mountpoint.ArgRegion, + want: "", + exists: false, + argsAfter: []string{ + "--aws-max-attempts=5", + "--cache=/tmp/s3-cache", + }, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + args := mountpoint.ParseArgs(testCase.args) + gotValue, gotExists := args.Remove(testCase.argToRemove) + assert.Equals(t, testCase.want, gotValue) + assert.Equals(t, testCase.exists, gotExists) + assert.Equals(t, testCase.argsAfter, args.SortedList()) + }) + } +} + +func TestQueryingExistenceOfAKeyInMountpointArgs(t *testing.T) { + args := mountpoint.ParseArgs([]string{ + "--allow-other", + "--cache /tmp/s3-cache", + "read-only", + }) + + assert.Equals(t, true, args.Has(mountpoint.ArgAllowOther)) + assert.Equals(t, true, args.Has(mountpoint.ArgCache)) + assert.Equals(t, true, args.Has(mountpoint.ArgReadOnly)) + assert.Equals(t, false, args.Has(mountpoint.ArgAllowRoot)) + assert.Equals(t, false, args.Has(mountpoint.ArgRegion)) +} + +func TestCreatingMountpointArgsFromAlreadyParsedArgs(t *testing.T) { + args := mountpoint.ParseArgs([]string{ + "--allow-other", + "--cache /tmp/s3-cache", + "read-only", + }) + args.Set("--user-agent-prefix", "s3-csi-driver/1.11.0 credential-source#pod k8s/v1.30.6-eks-7f9249a") + + want := []string{ + "--allow-other", + "--cache=/tmp/s3-cache", + "--read-only", + "--user-agent-prefix=s3-csi-driver/1.11.0 credential-source#pod k8s/v1.30.6-eks-7f9249a", + } + assert.Equals(t, want, args.SortedList()) + + parsedArgs := mountpoint.ParseArgs(args.SortedList()) + assert.Equals(t, want, parsedArgs.SortedList()) +}