From 104bca5f7683a896bf25895eacbbc9a9e490d827 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Mon, 23 Dec 2024 13:07:15 +0000 Subject: [PATCH 1/9] Create `mountpoint.Args` for parsing and accessing Mountpoint args MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Burak Varlı --- pkg/driver/node/mounter/fake_mounter.go | 4 +- pkg/driver/node/mounter/mocks/mock_mount.go | 9 +- pkg/driver/node/mounter/mount_credentials.go | 1 - pkg/driver/node/mounter/mounter.go | 3 +- pkg/driver/node/mounter/systemd_mounter.go | 45 +-- .../node/mounter/systemd_mounter_test.go | 3 +- pkg/driver/node/node.go | 20 +- pkg/driver/node/node_test.go | 9 +- pkg/mountpoint/args.go | 119 ++++++ pkg/mountpoint/args_test.go | 372 ++++++++++++++++++ 10 files changed, 529 insertions(+), 56 deletions(-) create mode 100644 pkg/mountpoint/args.go create mode 100644 pkg/mountpoint/args_test.go diff --git a/pkg/driver/node/mounter/fake_mounter.go b/pkg/driver/node/mounter/fake_mounter.go index a33a3ec0..c97c40a0 100644 --- a/pkg/driver/node/mounter/fake_mounter.go +++ b/pkg/driver/node/mounter/fake_mounter.go @@ -1,9 +1,11 @@ package mounter +import "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" + type FakeMounter struct{} func (m *FakeMounter) Mount(bucketName string, target string, - credentials *MountCredentials, options []string) error { + credentials *MountCredentials, args mountpoint.Args) error { return nil } diff --git a/pkg/driver/node/mounter/mocks/mock_mount.go b/pkg/driver/node/mounter/mocks/mock_mount.go index 6d6053cc..fdfe8880 100644 --- a/pkg/driver/node/mounter/mocks/mock_mount.go +++ b/pkg/driver/node/mounter/mocks/mock_mount.go @@ -9,6 +9,7 @@ import ( reflect "reflect" mounter "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" + mountpoint "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" system "github.com/awslabs/aws-s3-csi-driver/pkg/system" gomock "github.com/golang/mock/gomock" ) @@ -105,17 +106,17 @@ func (mr *MockMounterMockRecorder) IsMountPoint(target interface{}) *gomock.Call } // Mount mocks base method. -func (m *MockMounter) Mount(bucketName, target string, credentials *mounter.MountCredentials, options []string) error { +func (m *MockMounter) Mount(bucketName, target string, credentials *mounter.MountCredentials, args mountpoint.Args) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Mount", bucketName, target, credentials, options) + ret := m.ctrl.Call(m, "Mount", bucketName, target, credentials, args) ret0, _ := ret[0].(error) return ret0 } // Mount indicates an expected call of Mount. -func (mr *MockMounterMockRecorder) Mount(bucketName, target, credentials, options interface{}) *gomock.Call { +func (mr *MockMounterMockRecorder) Mount(bucketName, target, credentials, args interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), bucketName, target, credentials, options) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), bucketName, target, credentials, args) } // Unmount mocks base method. diff --git a/pkg/driver/node/mounter/mount_credentials.go b/pkg/driver/node/mounter/mount_credentials.go index 92cdf010..b9a87e52 100644 --- a/pkg/driver/node/mounter/mount_credentials.go +++ b/pkg/driver/node/mounter/mount_credentials.go @@ -20,7 +20,6 @@ const ( MountpointCacheKey = "UNSTABLE_MOUNTPOINT_CACHE_KEY" defaultMountS3Path = "/usr/bin/mount-s3" userAgentPrefix = "--user-agent-prefix" - awsMaxAttemptsOption = "--aws-max-attempts" ) type MountCredentials struct { diff --git a/pkg/driver/node/mounter/mounter.go b/pkg/driver/node/mounter/mounter.go index ece7b7c2..4ea5d82c 100644 --- a/pkg/driver/node/mounter/mounter.go +++ b/pkg/driver/node/mounter/mounter.go @@ -5,6 +5,7 @@ import ( "context" "os" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" "github.com/awslabs/aws-s3-csi-driver/pkg/system" ) @@ -15,7 +16,7 @@ type ServiceRunner interface { // Mounter is an interface for mount operations type Mounter interface { - Mount(bucketName string, target string, credentials *MountCredentials, options []string) error + Mount(bucketName string, target string, credentials *MountCredentials, args mountpoint.Args) error Unmount(target string) error IsMountPoint(target string) (bool, error) } diff --git a/pkg/driver/node/mounter/systemd_mounter.go b/pkg/driver/node/mounter/systemd_mounter.go index e6c0db66..a7d49b0f 100644 --- a/pkg/driver/node/mounter/systemd_mounter.go +++ b/pkg/driver/node/mounter/systemd_mounter.go @@ -9,6 +9,7 @@ import ( "time" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/awsprofile" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" "github.com/awslabs/aws-s3-csi-driver/pkg/system" "github.com/google/uuid" "k8s.io/klog/v2" @@ -79,7 +80,7 @@ func (m *SystemdMounter) IsMountPoint(target string) (bool, error) { // // This method will create the target path if it does not exist and if there is an existing corrupt // mount, it will attempt an unmount before attempting the mount. -func (m *SystemdMounter) Mount(bucketName string, target string, credentials *MountCredentials, options []string) error { +func (m *SystemdMounter) Mount(bucketName string, target string, credentials *MountCredentials, args mountpoint.Args) error { if bucketName == "" { return fmt.Errorf("bucket name is empty") } @@ -147,14 +148,14 @@ func (m *SystemdMounter) Mount(bucketName string, target string, credentials *Mo env = credentials.Env(awsProfile) } - options, env = moveOptionToEnvironmentVariables(awsMaxAttemptsOption, awsMaxAttemptsEnv, options, env) - options = addUserAgentToOptions(options, UserAgent(authenticationSource, m.kubernetesVersion)) + args, env = moveArgumentsToEnv(args, env) + args = addUserAgentToArguments(args, UserAgent(authenticationSource, m.kubernetesVersion)) output, err := m.Runner.StartService(timeoutCtx, &system.ExecConfig{ Name: "mount-s3-" + m.MpVersion + "-" + uuid.New().String() + ".service", Description: "Mountpoint for Amazon S3 CSI driver FUSE daemon", ExecPath: m.MountS3Path, - Args: append(options, bucketName, target), + Args: append(args.SortedList(), bucketName, target), Env: env, }) @@ -168,36 +169,20 @@ func (m *SystemdMounter) Mount(bucketName string, target string, credentials *Mo return nil } -// Moves a parameter optionName from the options list to MP's environment variable list. We need this as options are -// passed to the driver in a single field, but MP sometimes only supports config from environment variables. -// Returns an updated options and environment. -func moveOptionToEnvironmentVariables(optionName string, envName string, options []string, env []string) ([]string, []string) { - optionIdx := -1 - for i, o := range options { - if strings.HasPrefix(o, optionName) { - optionIdx = i - break - } - } - if optionIdx != -1 { - // We can do replace here as we've just verified it has the right prefix - env = append(env, strings.Replace(options[optionIdx], optionName, envName, 1)) - options = append(options[:optionIdx], options[optionIdx+1:]...) +func moveArgumentsToEnv(args mountpoint.Args, env []string) (mountpoint.Args, []string) { + if maxAttempts, ok := args.Remove(mountpoint.ArgAWSMaxAttempts); ok { + env = append(env, fmt.Sprintf("%s=%s", awsMaxAttemptsEnv, maxAttempts)) } - return options, env + return args, env } -// method to add the user agent prefix to the Mountpoint headers +// method to add the user agent prefix to the Mountpoint arguments. // https://github.com/awslabs/mountpoint-s3/pull/548 -func addUserAgentToOptions(options []string, userAgent string) []string { - // first remove it from the options in case it's in there - for i := len(options) - 1; i >= 0; i-- { - if strings.Contains(options[i], userAgentPrefix) { - options = append(options[:i], options[i+1:]...) - } - } - // add the hard coded S3 CSI driver user agent string - return append(options, userAgentPrefix+"="+userAgent) +func addUserAgentToArguments(args mountpoint.Args, userAgent string) mountpoint.Args { + // Remove existing user-agent if provided to ensure we always use the correct user-agent + _, _ = args.Remove(mountpoint.ArgUserAgentPrefix) + args.Insert(fmt.Sprintf("%s=%s", mountpoint.ArgUserAgentPrefix, userAgent)) + return args } func (m *SystemdMounter) Unmount(target string) error { diff --git a/pkg/driver/node/mounter/systemd_mounter_test.go b/pkg/driver/node/mounter/systemd_mounter_test.go index 86c89689..aec26d53 100644 --- a/pkg/driver/node/mounter/systemd_mounter_test.go +++ b/pkg/driver/node/mounter/systemd_mounter_test.go @@ -13,6 +13,7 @@ import ( "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/awsprofile" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" mock_driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter/mocks" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" "github.com/awslabs/aws-s3-csi-driver/pkg/system" "github.com/golang/mock/gomock" "k8s.io/mount-utils" @@ -156,7 +157,7 @@ func TestS3MounterMount(t *testing.T) { testCase.before(t, env) } err := env.mounter.Mount(testCase.bucketName, testCase.targetPath, - testCase.credentials, testCase.options) + testCase.credentials, mountpoint.ParseArgs(testCase.options)) env.mockCtl.Finish() if err != nil && !testCase.expectedErr { t.Fatal(err) diff --git a/pkg/driver/node/node.go b/pkg/driver/node/node.go index b26670e9..4930d5b6 100644 --- a/pkg/driver/node/node.go +++ b/pkg/driver/node/node.go @@ -32,6 +32,7 @@ import ( "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/targetpath" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" ) const ( @@ -121,25 +122,16 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl mountpointArgs := []string{} if req.GetReadonly() || volCap.GetAccessMode().GetMode() == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY { - mountpointArgs = append(mountpointArgs, "--read-only") + mountpointArgs = append(mountpointArgs, mountpoint.ArgReadOnly) } - // get the mount(point) options (yaml list) if capMount := volCap.GetMount(); capMount != nil { mountFlags := capMount.GetMountFlags() - for i := range mountFlags { - // trim left and right spaces - // trim spaces in between from multiple spaces to just one i.e. uid 1001 would turn into uid 1001 - // if there is a space between, replace it with an = sign - mountFlags[i] = strings.Replace(strings.Join(strings.Fields(strings.Trim(mountFlags[i], " ")), " "), " ", "=", -1) - // prepend -- if it's not already there - if !strings.HasPrefix(mountFlags[i], "-") { - mountFlags[i] = "--" + mountFlags[i] - } - } - mountpointArgs = compileMountOptions(mountpointArgs, mountFlags) + mountpointArgs = append(mountpointArgs, mountFlags...) } + args := mountpoint.ParseArgs(mountpointArgs) + credentials, err := ns.credentialProvider.Provide(ctx, req.VolumeId, req.VolumeContext, mountpointArgs) if err != nil { klog.Errorf("NodePublishVolume: failed to provide credentials: %v", err) @@ -148,7 +140,7 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl klog.V(4).Infof("NodePublishVolume: mounting %s at %s with options %v", bucket, target, mountpointArgs) - if err := ns.Mounter.Mount(bucket, target, credentials, mountpointArgs); err != nil { + if err := ns.Mounter.Mount(bucket, target, credentials, args); err != nil { os.Remove(target) return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", bucket, target, err) } diff --git a/pkg/driver/node/node_test.go b/pkg/driver/node/node_test.go index 7db0178e..040f1635 100644 --- a/pkg/driver/node/node_test.go +++ b/pkg/driver/node/node_test.go @@ -11,6 +11,7 @@ import ( "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" mock_driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter/mocks" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" csi "github.com/container-storage-interface/spec/lib/go/csi" "github.com/golang/mock/gomock" @@ -100,7 +101,7 @@ func TestNodePublishVolume(t *testing.T) { VolumeContext: map[string]string{"bucketName": bucketName}, } - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq([]string{"--read-only"})) + nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq(mountpoint.ParseArgs([]string{"--read-only"}))) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -131,7 +132,7 @@ func TestNodePublishVolume(t *testing.T) { Readonly: true, } - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq([]string{"--bar", "--foo", "--read-only", "--test=123"})) + nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq(mountpoint.ParseArgs([]string{"--bar", "--foo", "--read-only", "--test=123"}))) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -164,7 +165,7 @@ func TestNodePublishVolume(t *testing.T) { nodeTestEnv.mockMounter.EXPECT().Mount( gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), - gomock.Eq([]string{"--read-only", "--test=123"})).Return(nil) + gomock.Eq(mountpoint.ParseArgs([]string{"--read-only", "--test=123"}))).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -341,7 +342,7 @@ var _ mounter.Mounter = &dummyMounter{} type dummyMounter struct { } -func (d *dummyMounter) Mount(bucketName string, target string, credentials *mounter.MountCredentials, options []string) error { +func (d *dummyMounter) Mount(bucketName string, target string, credentials *mounter.MountCredentials, args mountpoint.Args) error { return nil } func (d *dummyMounter) Unmount(target string) error { diff --git a/pkg/mountpoint/args.go b/pkg/mountpoint/args.go new file mode 100644 index 00000000..988dc273 --- /dev/null +++ b/pkg/mountpoint/args.go @@ -0,0 +1,119 @@ +package mountpoint + +import ( + "strings" + + "k8s.io/apimachinery/pkg/util/sets" +) + +const ( + ArgForeground = "--foreground" + ArgReadOnly = "--read-only" + ArgAllowOther = "--allow-other" + ArgAllowRoot = "--allow-root" + ArgRegion = "--region" + ArgCache = "--cache" + ArgUserAgentPrefix = "--user-agent-prefix" + ArgAWSMaxAttempts = "--aws-max-attempts" +) + +// An Args represents arguments to be passed to Mountpoint during mount. +type Args struct { + args sets.Set[string] +} + +// ParseArgs parses given list of unnormalized and returns a normalized [Args]. +func ParseArgs(passedArgs []string) Args { + args := sets.New[string]() + + for _, arg := range passedArgs { + // trim left and right spaces + // trim spaces in between from multiple spaces to just one i.e. uid 1001 would turn into uid 1001 + // if there is a space between, replace it with an = sign + arg = strings.Replace(strings.Join(strings.Fields(strings.Trim(arg, " ")), " "), " ", "=", -1) + // prepend -- if it's not already there + if !strings.HasPrefix(arg, "-") { + arg = "--" + arg + } + + // disallow options that don't make sense in CSI + switch arg { + case "--foreground", "-f", "--help", "-h", "--version", "-v": + continue + } + + args.Insert(arg) + } + + return Args{args} +} + +// ParsedArgs creates [Args] from already parsed arguments by [ParseArgs]. +func ParsedArgs(parsedArgs []string) Args { + return Args{args: sets.New(parsedArgs...)} +} + +// Insert inserts given normalized argument to [Args] if not exists. +func (a *Args) Insert(arg string) { + a.args.Insert(arg) +} + +// Value extracts value of given key, it returns extracted value and whether the key was found. +func (a *Args) Value(key string) (string, bool) { + _, val, exists := a.find(key) + return val, exists +} + +// Has returns whether given key exists in [Args]. +func (a *Args) Has(key string) bool { + _, _, exists := a.find(key) + return exists +} + +// Remove removes given key, it returns the key's value and whether the key was found. +func (a *Args) Remove(key string) (string, bool) { + entry, val, exists := a.find(key) + if exists { + a.args.Delete(entry) + } + return val, exists +} + +// SortedList returns ordered list of normalized arguments. +func (a *Args) SortedList() []string { + return sets.List(a.args) +} + +// find tries to find given key from [Args], and returns whole entry, value and whether the key was found. +func (a *Args) find(key string) (string, string, bool) { + key, prefix := a.keysForSearch(key) + + for _, arg := range a.args.UnsortedList() { + if key == arg { + return key, "", true + } + + if strings.HasPrefix(arg, prefix) { + val := strings.SplitN(arg, "=", 2)[1] + return arg, val, true + } + } + + return "", "", false +} + +// keysForSearch returns whole key and a prefix to search for given key in [Args]. +// First one is the whole key to look for without `=` at the end for option-like arguments without any value. +// Second one is a prefix with `=` at the end for arguments with value. +// +// Arguments are normalized to `--key[=value]` in [ParseArgs], here this function also makes sure +// the returned prefixes have the same prefix format for the given key. +func (a *Args) keysForSearch(key string) (string, string) { + prefix := strings.TrimSuffix(key, "=") + + if !strings.HasPrefix(key, "-") { + prefix = "--" + prefix + } + + return prefix, prefix + "=" +} diff --git a/pkg/mountpoint/args_test.go b/pkg/mountpoint/args_test.go new file mode 100644 index 00000000..7af819c0 --- /dev/null +++ b/pkg/mountpoint/args_test.go @@ -0,0 +1,372 @@ +package mountpoint_test + +import ( + "testing" + + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" +) + +func TestParsingMountpointArgs(t *testing.T) { + testCases := []struct { + name string + input []string + want []string + }{ + { + name: "no prefix", + input: []string{ + "allow-delete", + "region us-west-2", + "aws-max-attempts 5", + }, + want: []string{ + "--allow-delete", + "--aws-max-attempts=5", + "--region=us-west-2", + }, + }, + { + name: "with prefix", + input: []string{ + "--cache /tmp/s3-cache", + "--max-cache-size 500", + "--metadata-ttl 3", + }, + want: []string{ + "--cache=/tmp/s3-cache", + "--max-cache-size=500", + "--metadata-ttl=3", + }, + }, + { + name: "with equals but no prefix", + input: []string{ + "allow-delete", + "region=us-west-2", + "sse=aws:kms", + "sse-kms-key-id=arn:aws:kms:us-west-2:012345678900:key/00000000-0000-0000-0000-000000000000", + }, + want: []string{ + "--allow-delete", + "--region=us-west-2", + "--sse-kms-key-id=arn:aws:kms:us-west-2:012345678900:key/00000000-0000-0000-0000-000000000000", + "--sse=aws:kms", + }, + }, + { + name: "with equals and prefix", + input: []string{ + "--allow-other", + "--uid=1000", + "--gid=2000", + }, + want: []string{ + "--allow-other", + "--gid=2000", + "--uid=1000", + }, + }, + { + name: "with multiple spaces", + input: []string{ + "--allow-other", + "--uid 1000", + "--gid 2000", + }, + want: []string{ + "--allow-other", + "--gid=2000", + "--uid=1000", + }, + }, + { + name: "with single dash prefix", + input: []string{ + "-d", + "-l logs/", + }, + want: []string{ + "-d", + "-l=logs/", + }, + }, + { + name: "mixed prefix and equal signs", + input: []string{ + "--allow-delete", + "read-only", + "--cache=/tmp/s3-cache", + "--region us-east-1", + "prefix some-s3-prefix/", + "-d", + "-l=logs/", + }, + want: []string{ + "--allow-delete", + "--cache=/tmp/s3-cache", + "--prefix=some-s3-prefix/", + "--read-only", + "--region=us-east-1", + "-d", + "-l=logs/", + }, + }, + { + name: "with duplicated options", + input: []string{ + "--allow-other", + "--read-only", + "read-only", + "--allow-other", + }, + want: []string{ + "--allow-other", + "--read-only", + }, + }, + { + name: "with unsupported options", + input: []string{ + "--allow-other", + "--read-only", + "--foreground", "-f", + "--help", "-h", + "--version", "-v", + }, + want: []string{ + "--allow-other", + "--read-only", + }, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + args := mountpoint.ParseArgs(testCase.input) + assert.Equals(t, testCase.want, args.SortedList()) + }) + } +} + +func TestInsertingNewArgsToMountpointArgs(t *testing.T) { + testCases := []struct { + name string + existingArgs []string + newArg string + want []string + }{ + { + name: "new arg", + existingArgs: []string{ + "allow-delete", + "region us-west-2", + "aws-max-attempts 5", + }, + newArg: mountpoint.ArgReadOnly, + want: []string{ + "--allow-delete", + "--aws-max-attempts=5", + "--read-only", + "--region=us-west-2", + }, + }, + { + name: "existing arg", + existingArgs: []string{ + "allow-delete", + "read-only", + "region us-west-2", + "aws-max-attempts 5", + }, + newArg: mountpoint.ArgReadOnly, + want: []string{ + "--allow-delete", + "--aws-max-attempts=5", + "--read-only", + "--region=us-west-2", + }, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + args := mountpoint.ParseArgs(testCase.existingArgs) + args.Insert(testCase.newArg) + assert.Equals(t, testCase.want, args.SortedList()) + }) + } +} + +func TestExtractingAnArgumentsValueFromMountpointArgs(t *testing.T) { + testCases := []struct { + name string + args []string + argToExtract string + want string + exists bool + }{ + { + name: "existing argument", + args: []string{ + "cache /tmp/s3-cache", + "region us-west-2", + "aws-max-attempts 5", + }, + argToExtract: mountpoint.ArgCache, + want: "/tmp/s3-cache", + exists: true, + }, + { + name: "existing argument with equal", + args: []string{ + "cache /tmp/s3-cache", + "region=us-west-2", + "aws-max-attempts 5", + }, + argToExtract: mountpoint.ArgRegion, + want: "us-west-2", + exists: true, + }, + { + name: "existing argument queried without prefix", + args: []string{ + "cache /tmp/s3-cache", + "region=us-west-2", + "aws-max-attempts 5", + }, + argToExtract: "region", + want: "us-west-2", + exists: true, + }, + { + name: "existing argument queried without equals", + args: []string{ + "cache /tmp/s3-cache", + "region=us-west-2", + "aws-max-attempts 5", + }, + argToExtract: "--region", + want: "us-west-2", + exists: true, + }, + { + name: "non-existent argument", + args: []string{ + "cache /tmp/s3-cache", + "aws-max-attempts 5", + }, + argToExtract: mountpoint.ArgRegion, + want: "", + exists: false, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + args := mountpoint.ParseArgs(testCase.args) + gotValue, gotExists := args.Value(testCase.argToExtract) + assert.Equals(t, testCase.want, gotValue) + assert.Equals(t, testCase.exists, gotExists) + }) + } +} + +func TestRemovingAnArgumentFromMountpointArgs(t *testing.T) { + testCases := []struct { + name string + args []string + argToRemove string + want string + argsAfter []string + exists bool + }{ + { + name: "existing argument", + args: []string{ + "user-agent-prefix foo/bar", + "cache /tmp/s3-cache", + "region us-west-2", + "aws-max-attempts 5", + }, + argToRemove: mountpoint.ArgUserAgentPrefix, + want: "foo/bar", + exists: true, + argsAfter: []string{ + "--aws-max-attempts=5", + "--cache=/tmp/s3-cache", + "--region=us-west-2", + }, + }, + { + name: "existing argument with equal", + args: []string{ + "cache /tmp/s3-cache", + "region=us-west-2", + "aws-max-attempts=5", + }, + argToRemove: mountpoint.ArgAWSMaxAttempts, + want: "5", + exists: true, + argsAfter: []string{ + "--cache=/tmp/s3-cache", + "--region=us-west-2", + }, + }, + { + name: "non-existent argument", + args: []string{ + "cache /tmp/s3-cache", + "aws-max-attempts 5", + }, + argToRemove: mountpoint.ArgRegion, + want: "", + exists: false, + argsAfter: []string{ + "--aws-max-attempts=5", + "--cache=/tmp/s3-cache", + }, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + args := mountpoint.ParseArgs(testCase.args) + gotValue, gotExists := args.Remove(testCase.argToRemove) + assert.Equals(t, testCase.want, gotValue) + assert.Equals(t, testCase.exists, gotExists) + assert.Equals(t, testCase.argsAfter, args.SortedList()) + }) + } +} + +func TestQueryingExistenceOfAKeyInMountpointArgs(t *testing.T) { + args := mountpoint.ParseArgs([]string{ + "--allow-other", + "--cache /tmp/s3-cache", + "read-only", + }) + + assert.Equals(t, true, args.Has(mountpoint.ArgAllowOther)) + assert.Equals(t, true, args.Has(mountpoint.ArgCache)) + assert.Equals(t, true, args.Has(mountpoint.ArgReadOnly)) + assert.Equals(t, false, args.Has(mountpoint.ArgAllowRoot)) + assert.Equals(t, false, args.Has(mountpoint.ArgRegion)) +} + +func TestCreatingMountpointArgsFromAlreadyParsedArgs(t *testing.T) { + args := mountpoint.ParseArgs([]string{ + "--allow-other", + "--cache /tmp/s3-cache", + "read-only", + }) + args.Insert("--user-agent-prefix=s3-csi-driver/1.11.0 credential-source#pod k8s/v1.30.6-eks-7f9249a") + + want := []string{ + "--allow-other", + "--cache=/tmp/s3-cache", + "--read-only", + "--user-agent-prefix=s3-csi-driver/1.11.0 credential-source#pod k8s/v1.30.6-eks-7f9249a", + } + + assert.Equals(t, want, args.SortedList()) + + parsedArgs := mountpoint.ParsedArgs(args.SortedList()) + assert.Equals(t, want, parsedArgs.SortedList()) +} From 8e982041ff5a0f90bd8d0586226f1f74a27ba291 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Mon, 20 Jan 2025 13:31:17 +0000 Subject: [PATCH 2/9] Store arguments as key value pairs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Burak Varlı --- pkg/driver/node/mounter/systemd_mounter.go | 3 +- pkg/mountpoint/args.go | 133 ++++++++++++--------- pkg/mountpoint/args_test.go | 55 +++++++-- 3 files changed, 123 insertions(+), 68 deletions(-) diff --git a/pkg/driver/node/mounter/systemd_mounter.go b/pkg/driver/node/mounter/systemd_mounter.go index a7d49b0f..c4254106 100644 --- a/pkg/driver/node/mounter/systemd_mounter.go +++ b/pkg/driver/node/mounter/systemd_mounter.go @@ -180,8 +180,7 @@ func moveArgumentsToEnv(args mountpoint.Args, env []string) (mountpoint.Args, [] // https://github.com/awslabs/mountpoint-s3/pull/548 func addUserAgentToArguments(args mountpoint.Args, userAgent string) mountpoint.Args { // Remove existing user-agent if provided to ensure we always use the correct user-agent - _, _ = args.Remove(mountpoint.ArgUserAgentPrefix) - args.Insert(fmt.Sprintf("%s=%s", mountpoint.ArgUserAgentPrefix, userAgent)) + args.Set(mountpoint.ArgUserAgentPrefix, userAgent) return args } diff --git a/pkg/mountpoint/args.go b/pkg/mountpoint/args.go index 988dc273..f7bd20da 100644 --- a/pkg/mountpoint/args.go +++ b/pkg/mountpoint/args.go @@ -1,6 +1,8 @@ package mountpoint import ( + "fmt" + "slices" "strings" "k8s.io/apimachinery/pkg/util/sets" @@ -17,103 +19,126 @@ const ( ArgAWSMaxAttempts = "--aws-max-attempts" ) +// An ArgKey represents the key of an argument. +type ArgKey = string + +// An ArgValue represents the value of an argument. +type ArgValue = string + +// A value to use in arguments without any value, i.e., an option. +const ArgNoValue = "" + +// An Arg represents an argument to be passed to Mountpoint. +type Arg struct { + key ArgKey + value ArgValue +} + +// String returns string representation of the argument to pass Mountpoint. +func (a *Arg) String() string { + if a.value == ArgNoValue { + return a.key + } + return fmt.Sprintf("%s=%s", a.key, a.value) +} + // An Args represents arguments to be passed to Mountpoint during mount. type Args struct { - args sets.Set[string] + args sets.Set[Arg] } // ParseArgs parses given list of unnormalized and returns a normalized [Args]. func ParseArgs(passedArgs []string) Args { - args := sets.New[string]() + args := sets.New[Arg]() for _, arg := range passedArgs { - // trim left and right spaces - // trim spaces in between from multiple spaces to just one i.e. uid 1001 would turn into uid 1001 - // if there is a space between, replace it with an = sign - arg = strings.Replace(strings.Join(strings.Fields(strings.Trim(arg, " ")), " "), " ", "=", -1) - // prepend -- if it's not already there - if !strings.HasPrefix(arg, "-") { - arg = "--" + arg + var key, value string + + parts := strings.SplitN(strings.Trim(arg, " "), "=", 2) + if len(parts) == 2 { + // Ex: --key=value, key=value + key, value = parts[0], parts[1] + } else { + // Ex: --key value, key value + // Ex: --key, key + parts = strings.SplitN(strings.Trim(parts[0], " "), " ", 2) + if len(parts) == 1 { + // Ex: --key, key + key = parts[0] + } else { + // Ex: --key value, key value + key, value = parts[0], strings.Trim(parts[1], " ") + } } + // prepend -- if it's not already there + key = normalizeKey(key) + // disallow options that don't make sense in CSI - switch arg { + switch key { case "--foreground", "-f", "--help", "-h", "--version", "-v": continue } - args.Insert(arg) + args.Insert(Arg{key, value}) } return Args{args} } -// ParsedArgs creates [Args] from already parsed arguments by [ParseArgs]. -func ParsedArgs(parsedArgs []string) Args { - return Args{args: sets.New(parsedArgs...)} -} - -// Insert inserts given normalized argument to [Args] if not exists. -func (a *Args) Insert(arg string) { - a.args.Insert(arg) +// Set sets or replaces value of given key. +func (a *Args) Set(key ArgKey, value ArgValue) { + key = normalizeKey(key) + a.Remove(key) + a.args.Insert(Arg{key, value}) } // Value extracts value of given key, it returns extracted value and whether the key was found. -func (a *Args) Value(key string) (string, bool) { - _, val, exists := a.find(key) - return val, exists +func (a *Args) Value(key ArgKey) (ArgValue, bool) { + arg, exists := a.find(key) + return arg.value, exists } // Has returns whether given key exists in [Args]. -func (a *Args) Has(key string) bool { - _, _, exists := a.find(key) +func (a *Args) Has(key ArgKey) bool { + _, exists := a.find(key) return exists } // Remove removes given key, it returns the key's value and whether the key was found. -func (a *Args) Remove(key string) (string, bool) { - entry, val, exists := a.find(key) +func (a *Args) Remove(key ArgKey) (ArgValue, bool) { + arg, exists := a.find(key) if exists { - a.args.Delete(entry) + a.args.Delete(arg) } - return val, exists + return arg.value, exists } // SortedList returns ordered list of normalized arguments. func (a *Args) SortedList() []string { - return sets.List(a.args) + args := make([]string, 0, a.args.Len()) + for _, arg := range a.args.UnsortedList() { + args = append(args, arg.String()) + } + slices.Sort(args) + return args } -// find tries to find given key from [Args], and returns whole entry, value and whether the key was found. -func (a *Args) find(key string) (string, string, bool) { - key, prefix := a.keysForSearch(key) - +// find tries to find given key from [Args], and returns whole entry, and whether the key was found. +func (a *Args) find(key ArgKey) (Arg, bool) { + key = normalizeKey(key) for _, arg := range a.args.UnsortedList() { - if key == arg { - return key, "", true - } - - if strings.HasPrefix(arg, prefix) { - val := strings.SplitN(arg, "=", 2)[1] - return arg, val, true + if key == arg.key { + return arg, true } } - - return "", "", false + return Arg{}, false } -// keysForSearch returns whole key and a prefix to search for given key in [Args]. -// First one is the whole key to look for without `=` at the end for option-like arguments without any value. -// Second one is a prefix with `=` at the end for arguments with value. -// -// Arguments are normalized to `--key[=value]` in [ParseArgs], here this function also makes sure -// the returned prefixes have the same prefix format for the given key. -func (a *Args) keysForSearch(key string) (string, string) { - prefix := strings.TrimSuffix(key, "=") - +// normalizeKey normalized given key to have a "--" prefix. +func normalizeKey(key ArgKey) ArgKey { if !strings.HasPrefix(key, "-") { - prefix = "--" + prefix + return "--" + key } - - return prefix, prefix + "=" + return key } diff --git a/pkg/mountpoint/args_test.go b/pkg/mountpoint/args_test.go index 7af819c0..c93d1491 100644 --- a/pkg/mountpoint/args_test.go +++ b/pkg/mountpoint/args_test.go @@ -148,21 +148,22 @@ func TestParsingMountpointArgs(t *testing.T) { } } -func TestInsertingNewArgsToMountpointArgs(t *testing.T) { +func TestInsertingArgsToMountpointArgs(t *testing.T) { testCases := []struct { name string existingArgs []string - newArg string + key string + value string want []string }{ { - name: "new arg", + name: "new option", existingArgs: []string{ "allow-delete", "region us-west-2", "aws-max-attempts 5", }, - newArg: mountpoint.ArgReadOnly, + key: mountpoint.ArgReadOnly, want: []string{ "--allow-delete", "--aws-max-attempts=5", @@ -171,14 +172,14 @@ func TestInsertingNewArgsToMountpointArgs(t *testing.T) { }, }, { - name: "existing arg", + name: "existing option", existingArgs: []string{ "allow-delete", "read-only", "region us-west-2", "aws-max-attempts 5", }, - newArg: mountpoint.ArgReadOnly, + key: mountpoint.ArgReadOnly, want: []string{ "--allow-delete", "--aws-max-attempts=5", @@ -186,11 +187,42 @@ func TestInsertingNewArgsToMountpointArgs(t *testing.T) { "--region=us-west-2", }, }, + { + name: "new arg", + existingArgs: []string{ + "allow-delete", + "aws-max-attempts 5", + }, + key: mountpoint.ArgRegion, + value: "us-west-2", + want: []string{ + "--allow-delete", + "--aws-max-attempts=5", + "--region=us-west-2", + }, + }, + { + name: "existing arg", + existingArgs: []string{ + "allow-delete", + "read-only", + "region us-west-2", + "aws-max-attempts 5", + }, + key: mountpoint.ArgRegion, + value: "us-east-1", + want: []string{ + "--allow-delete", + "--aws-max-attempts=5", + "--read-only", + "--region=us-east-1", + }, + }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { args := mountpoint.ParseArgs(testCase.existingArgs) - args.Insert(testCase.newArg) + args.Set(testCase.key, testCase.value) assert.Equals(t, testCase.want, args.SortedList()) }) } @@ -296,13 +328,13 @@ func TestRemovingAnArgumentFromMountpointArgs(t *testing.T) { }, }, { - name: "existing argument with equal", + name: "existing argument without key prefix", args: []string{ "cache /tmp/s3-cache", "region=us-west-2", "aws-max-attempts=5", }, - argToRemove: mountpoint.ArgAWSMaxAttempts, + argToRemove: "aws-max-attempts", want: "5", exists: true, argsAfter: []string{ @@ -356,7 +388,7 @@ func TestCreatingMountpointArgsFromAlreadyParsedArgs(t *testing.T) { "--cache /tmp/s3-cache", "read-only", }) - args.Insert("--user-agent-prefix=s3-csi-driver/1.11.0 credential-source#pod k8s/v1.30.6-eks-7f9249a") + args.Set("--user-agent-prefix", "s3-csi-driver/1.11.0 credential-source#pod k8s/v1.30.6-eks-7f9249a") want := []string{ "--allow-other", @@ -364,9 +396,8 @@ func TestCreatingMountpointArgsFromAlreadyParsedArgs(t *testing.T) { "--read-only", "--user-agent-prefix=s3-csi-driver/1.11.0 credential-source#pod k8s/v1.30.6-eks-7f9249a", } - assert.Equals(t, want, args.SortedList()) - parsedArgs := mountpoint.ParsedArgs(args.SortedList()) + parsedArgs := mountpoint.ParseArgs(args.SortedList()) assert.Equals(t, want, parsedArgs.SortedList()) } From 1363e3c7c9824e777e2ccafbee0109463cab1044 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Tue, 7 Jan 2025 18:31:54 +0000 Subject: [PATCH 3/9] Use `mountpoint.Args` in `aws-s3-csi-mounter` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Burak Varlı --- cmd/aws-s3-csi-mounter/csimounter/csimounter.go | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/cmd/aws-s3-csi-mounter/csimounter/csimounter.go b/cmd/aws-s3-csi-mounter/csimounter/csimounter.go index 6b5ba551..84d5afd3 100644 --- a/cmd/aws-s3-csi-mounter/csimounter/csimounter.go +++ b/cmd/aws-s3-csi-mounter/csimounter/csimounter.go @@ -7,10 +7,10 @@ import ( "io/fs" "os" "os/exec" - "slices" "k8s.io/klog/v2" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mountoptions" "github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod" ) @@ -52,20 +52,17 @@ func Run(options Options) (int, error) { return 0, fmt.Errorf("passed file descriptor %d is invalid", mountOptions.Fd) } - args := mountOptions.Args + mountpointArgs := mountpoint.ParseArgs(mountOptions.Args) // By default Mountpoint runs in a detached mode. Here we want to monitor it by relaying its output, // and also we want to wait until it terminates. We're passing `--foreground` to achieve this. - const foreground, foregroundShort = "--foreground", "-f" - if !(slices.Contains(args, foreground) || slices.Contains(args, foregroundShort)) { - args = append(args, foreground) - } + mountpointArgs.Set(mountpoint.ArgForeground, mountpoint.ArgNoValue) - args = append([]string{ + args := append([]string{ mountOptions.BucketName, // We pass FUSE fd using `ExtraFiles`, and each entry becomes as file descriptor 3+i. "/dev/fd/3", - }, args...) + }, mountpointArgs.SortedList()...) cmd := exec.Command(options.MountpointPath, args...) cmd.ExtraFiles = []*os.File{fuseDev} From 90245acf104d26b1f0884214369da01e9bcaec3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Mon, 20 Jan 2025 13:43:05 +0000 Subject: [PATCH 4/9] Remove redundant comment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Burak Varlı --- pkg/driver/node/mounter/systemd_mounter.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/driver/node/mounter/systemd_mounter.go b/pkg/driver/node/mounter/systemd_mounter.go index c4254106..743de6a8 100644 --- a/pkg/driver/node/mounter/systemd_mounter.go +++ b/pkg/driver/node/mounter/systemd_mounter.go @@ -179,7 +179,6 @@ func moveArgumentsToEnv(args mountpoint.Args, env []string) (mountpoint.Args, [] // method to add the user agent prefix to the Mountpoint arguments. // https://github.com/awslabs/mountpoint-s3/pull/548 func addUserAgentToArguments(args mountpoint.Args, userAgent string) mountpoint.Args { - // Remove existing user-agent if provided to ensure we always use the correct user-agent args.Set(mountpoint.ArgUserAgentPrefix, userAgent) return args } From e35a62924c0b845623bfa2f1028c53f0afe607ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Tue, 21 Jan 2025 13:34:50 +0000 Subject: [PATCH 5/9] Remove `addUserAgentToArguments` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Burak Varlı --- pkg/driver/node/mounter/systemd_mounter.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/pkg/driver/node/mounter/systemd_mounter.go b/pkg/driver/node/mounter/systemd_mounter.go index 743de6a8..59f5ab86 100644 --- a/pkg/driver/node/mounter/systemd_mounter.go +++ b/pkg/driver/node/mounter/systemd_mounter.go @@ -149,7 +149,7 @@ func (m *SystemdMounter) Mount(bucketName string, target string, credentials *Mo env = credentials.Env(awsProfile) } args, env = moveArgumentsToEnv(args, env) - args = addUserAgentToArguments(args, UserAgent(authenticationSource, m.kubernetesVersion)) + args.Set(mountpoint.ArgUserAgentPrefix, UserAgent(authenticationSource, m.kubernetesVersion)) output, err := m.Runner.StartService(timeoutCtx, &system.ExecConfig{ Name: "mount-s3-" + m.MpVersion + "-" + uuid.New().String() + ".service", @@ -176,13 +176,6 @@ func moveArgumentsToEnv(args mountpoint.Args, env []string) (mountpoint.Args, [] return args, env } -// method to add the user agent prefix to the Mountpoint arguments. -// https://github.com/awslabs/mountpoint-s3/pull/548 -func addUserAgentToArguments(args mountpoint.Args, userAgent string) mountpoint.Args { - args.Set(mountpoint.ArgUserAgentPrefix, userAgent) - return args -} - func (m *SystemdMounter) Unmount(target string) error { timeoutCtx, cancel := context.WithTimeout(m.Ctx, 30*time.Second) defer cancel() From 35491f7592133dd77f94aa307b201563f92beb2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Tue, 21 Jan 2025 13:42:38 +0000 Subject: [PATCH 6/9] Use `mountpoint.Args` in `CredentialProvider` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Burak Varlı --- .../node/mounter/credential_provider.go | 13 +++---- .../node/mounter/credential_provider_test.go | 35 ++++++++++--------- pkg/driver/node/mounter/systemd_mounter.go | 19 ---------- .../node/mounter/systemd_mounter_test.go | 35 ------------------- pkg/driver/node/node.go | 4 +-- 5 files changed, 27 insertions(+), 79 deletions(-) diff --git a/pkg/driver/node/mounter/credential_provider.go b/pkg/driver/node/mounter/credential_provider.go index 12cc4a82..4b90d0f3 100644 --- a/pkg/driver/node/mounter/credential_provider.go +++ b/pkg/driver/node/mounter/credential_provider.go @@ -22,6 +22,7 @@ import ( k8sstrings "k8s.io/utils/strings" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" ) const hostPluginDirEnv = "HOST_PLUGIN_DIR" @@ -84,7 +85,7 @@ func (c *CredentialProvider) CleanupToken(volumeID string, podID string) error { // Provide provides mount credentials for given volume and volume context. // Depending on the configuration, it either returns driver-level or pod-level credentials. -func (c *CredentialProvider) Provide(ctx context.Context, volumeID string, volumeCtx map[string]string, mountpointArgs []string) (*MountCredentials, error) { +func (c *CredentialProvider) Provide(ctx context.Context, volumeID string, volumeCtx map[string]string, args mountpoint.Args) (*MountCredentials, error) { if volumeCtx == nil { return nil, status.Error(codes.InvalidArgument, "Missing volume context") } @@ -92,7 +93,7 @@ func (c *CredentialProvider) Provide(ctx context.Context, volumeID string, volum authenticationSource := volumeCtx[volumecontext.AuthenticationSource] switch authenticationSource { case AuthenticationSourcePod: - return c.provideFromPod(ctx, volumeID, volumeCtx, mountpointArgs) + return c.provideFromPod(ctx, volumeID, volumeCtx, args) case AuthenticationSourceUnspecified, AuthenticationSourceDriver: return c.provideFromDriver() default: @@ -119,7 +120,7 @@ func (c *CredentialProvider) provideFromDriver() (*MountCredentials, error) { }, nil } -func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string, volumeCtx map[string]string, mountpointArgs []string) (*MountCredentials, error) { +func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string, volumeCtx map[string]string, args mountpoint.Args) (*MountCredentials, error) { klog.V(4).Infof("NodePublishVolume: Using pod identity") tokensJson := volumeCtx[volumecontext.CSIServiceAccountTokens] @@ -144,7 +145,7 @@ func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string return nil, err } - region, err := c.stsRegion(volumeCtx, mountpointArgs) + region, err := c.stsRegion(volumeCtx, args) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "Failed to detect STS AWS Region, please explicitly set the AWS Region, see "+stsConfigDocsPage) } @@ -245,14 +246,14 @@ func (c *CredentialProvider) findPodServiceAccountRole(ctx context.Context, volu // 4. Calling IMDS to detect region // // It returns an error if all of them fails. -func (c *CredentialProvider) stsRegion(volumeCtx map[string]string, mountpointArgs []string) (string, error) { +func (c *CredentialProvider) stsRegion(volumeCtx map[string]string, args mountpoint.Args) (string, error) { region := volumeCtx[volumecontext.STSRegion] if region != "" { klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from volume context", region) return region, nil } - if region, ok := ExtractMountpointArgument(mountpointArgs, mountpointArgRegion); ok { + if region, ok := args.Value(mountpoint.ArgRegion); ok { klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from S3 bucket region", region) return region, nil } diff --git a/pkg/driver/node/mounter/credential_provider_test.go b/pkg/driver/node/mounter/credential_provider_test.go index 26f5a7c9..ac8c23e2 100644 --- a/pkg/driver/node/mounter/credential_provider_test.go +++ b/pkg/driver/node/mounter/credential_provider_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -42,7 +43,7 @@ func TestProvidingDriverLevelCredentials(t *testing.T) { } { provider := mounter.NewCredentialProvider(nil, "", mounter.RegionFromIMDSOnce) - credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, nil) + credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) assertEquals(t, credentials.AccessKeyID, "test-access-key") @@ -58,7 +59,7 @@ func TestProvidingDriverLevelCredentials(t *testing.T) { func TestProvidingDriverLevelCredentialsWithEmptyEnv(t *testing.T) { provider := mounter.NewCredentialProvider(nil, "", mounter.RegionFromIMDSOnce) - credentials, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{"authenticationSource": "driver"}, nil) + credentials, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{"authenticationSource": "driver"}, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) assertEquals(t, credentials.AccessKeyID, "") @@ -93,7 +94,7 @@ func TestProvidingPodLevelCredentials(t *testing.T) { Token: "test-service-account-token", }, }), - }, nil) + }, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) // Should disable env variable provider @@ -216,7 +217,7 @@ func TestProvidingPodLevelCredentialsWithMissingInformation(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { - credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, nil) + credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, mountpoint.ParseArgs(nil)) assertEquals(t, nil, credentials) if err == nil { t.Error("it should fail with missing information") @@ -252,7 +253,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { return "", errors.New("unknown region") }) - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) assertEquals(t, nil, credentials) if err == nil { t.Error("it should fail if there is not any region information") @@ -268,7 +269,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { return "us-east-1", nil }) - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "us-east-1") assertEquals(t, credentials.DefaultRegion, "us-east-1") @@ -286,7 +287,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { t.Setenv("AWS_REGION", "eu-west-1") - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "eu-west-1") assertEquals(t, credentials.DefaultRegion, "eu-west-1") @@ -304,7 +305,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { t.Setenv("AWS_DEFAULT_REGION", "eu-west-1") - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "eu-west-1") assertEquals(t, credentials.DefaultRegion, "eu-west-1") @@ -323,7 +324,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { t.Setenv("AWS_REGION", "eu-west-1") t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "eu-west-1") assertEquals(t, credentials.DefaultRegion, "eu-north-1") @@ -341,7 +342,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { t.Setenv("AWS_REGION", "eu-west-1") - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"}) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"})) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "us-west-1") assertEquals(t, credentials.DefaultRegion, "us-west-1") @@ -359,7 +360,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { t.Setenv("AWS_REGION", "eu-west-1") - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--read-only"}) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--read-only"})) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "eu-west-1") assertEquals(t, credentials.DefaultRegion, "eu-west-1") @@ -378,7 +379,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { t.Setenv("AWS_REGION", "eu-west-1") t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"}) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"})) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "us-west-1") assertEquals(t, credentials.DefaultRegion, "eu-north-1") @@ -398,7 +399,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { volumeContext["stsRegion"] = "ap-south-1" - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"}) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"})) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "ap-south-1") assertEquals(t, credentials.DefaultRegion, "ap-south-1") @@ -419,7 +420,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { volumeContext["stsRegion"] = "ap-south-1" - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"}) + credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"})) assertEquals(t, nil, err) assertEquals(t, credentials.Region, "ap-south-1") assertEquals(t, credentials.DefaultRegion, "eu-north-1") @@ -457,7 +458,7 @@ func TestProvidingPodLevelCredentialsForDifferentPodsWithDifferentRoles(t *testi Token: "test-service-account-token-1", }, }), - }, nil) + }, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) credentialsPodTwo, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{ @@ -470,7 +471,7 @@ func TestProvidingPodLevelCredentialsForDifferentPodsWithDifferentRoles(t *testi Token: "test-service-account-token-2", }, }), - }, nil) + }, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) // PodOne @@ -526,7 +527,7 @@ func TestProvidingPodLevelCredentialsWithSlashInVolumeID(t *testing.T) { Token: "test-service-account-token", }, }), - }, nil) + }, mountpoint.ParseArgs(nil)) assertEquals(t, nil, err) assertEquals(t, credentials.AccessKeyID, "") diff --git a/pkg/driver/node/mounter/systemd_mounter.go b/pkg/driver/node/mounter/systemd_mounter.go index 59f5ab86..d73978c7 100644 --- a/pkg/driver/node/mounter/systemd_mounter.go +++ b/pkg/driver/node/mounter/systemd_mounter.go @@ -5,7 +5,6 @@ import ( "fmt" "os" "path/filepath" - "strings" "time" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/awsprofile" @@ -200,21 +199,3 @@ func (m *SystemdMounter) Unmount(target string) error { } return nil } - -const ( - mountpointArgRegion = "region" - mountpointArgCache = "cache" -) - -// ExtractMountpointArgument extracts value of a given argument from `mountpointArgs`. -func ExtractMountpointArgument(mountpointArgs []string, argument string) (string, bool) { - // `mountpointArgs` normalized to `--arg=val` in `S3NodeServer.NodePublishVolume`. - prefix := fmt.Sprintf("--%s=", argument) - for _, arg := range mountpointArgs { - if strings.HasPrefix(arg, prefix) { - val := strings.SplitN(arg, "=", 2)[1] - return val, true - } - } - return "", false -} diff --git a/pkg/driver/node/mounter/systemd_mounter_test.go b/pkg/driver/node/mounter/systemd_mounter_test.go index aec26d53..6b3fffe0 100644 --- a/pkg/driver/node/mounter/systemd_mounter_test.go +++ b/pkg/driver/node/mounter/systemd_mounter_test.go @@ -279,41 +279,6 @@ func TestProvidingEnvVariablesForMountpointProcess(t *testing.T) { } } -func TestExtractMountpointArgument(t *testing.T) { - for name, test := range map[string]struct { - input []string - argument string - expectedToFound bool - expectedValue string - }{ - "Extract Existing Argument": { - input: []string{ - "--region=us-east-1", - }, - argument: "region", - expectedToFound: true, - expectedValue: "us-east-1", - }, - "Extract Non Existing Argument": { - input: []string{ - "--bucket=test", - }, - argument: "region", - expectedToFound: false, - }, - "Extract Non Existing Argument With Empty Input": { - argument: "region", - expectedToFound: false, - }, - } { - t.Run(name, func(t *testing.T) { - val, found := mounter.ExtractMountpointArgument(test.input, test.argument) - assertEquals(t, test.expectedToFound, found) - assertEquals(t, test.expectedValue, val) - }) - } -} - func TestIsMountPoint(t *testing.T) { testDir := t.TempDir() mountpointS3MountPath := filepath.Join(testDir, "/var/lib/kubelet/pods/46efe8aa-75d9-4b12-8fdd-0ce0c2cabd99/volumes/kubernetes.io~csi/s3-mp-csi-pv/mount") diff --git a/pkg/driver/node/node.go b/pkg/driver/node/node.go index 4930d5b6..9ee0a01d 100644 --- a/pkg/driver/node/node.go +++ b/pkg/driver/node/node.go @@ -132,13 +132,13 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl args := mountpoint.ParseArgs(mountpointArgs) - credentials, err := ns.credentialProvider.Provide(ctx, req.VolumeId, req.VolumeContext, mountpointArgs) + credentials, err := ns.credentialProvider.Provide(ctx, req.VolumeId, req.VolumeContext, args) if err != nil { klog.Errorf("NodePublishVolume: failed to provide credentials: %v", err) return nil, err } - klog.V(4).Infof("NodePublishVolume: mounting %s at %s with options %v", bucket, target, mountpointArgs) + klog.V(4).Infof("NodePublishVolume: mounting %s at %s with options %v", bucket, target, args.SortedList()) if err := ns.Mounter.Mount(bucket, target, credentials, args); err != nil { os.Remove(target) From 21a0fe0d891ae509f46d60aef81a05c514e179a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Tue, 21 Jan 2025 13:43:58 +0000 Subject: [PATCH 7/9] Clarify comments on `ParseArgs` and set value to `ArgNoValue` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Burak Varlı --- pkg/mountpoint/args.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pkg/mountpoint/args.go b/pkg/mountpoint/args.go index f7bd20da..ecfd5b5e 100644 --- a/pkg/mountpoint/args.go +++ b/pkg/mountpoint/args.go @@ -56,17 +56,18 @@ func ParseArgs(passedArgs []string) Args { parts := strings.SplitN(strings.Trim(arg, " "), "=", 2) if len(parts) == 2 { - // Ex: --key=value, key=value + // Ex: `--key=value` or `key=value` key, value = parts[0], parts[1] } else { - // Ex: --key value, key value - // Ex: --key, key + // Ex: `--key value` or `key value` + // Ex: `--key` or `key` parts = strings.SplitN(strings.Trim(parts[0], " "), " ", 2) if len(parts) == 1 { - // Ex: --key, key + // Ex: `--key` or `key` key = parts[0] + value = ArgNoValue } else { - // Ex: --key value, key value + // Ex: `--key value` or `key value` key, value = parts[0], strings.Trim(parts[1], " ") } } From 41c14188bfda5eacea6c914373d82f9e2e810715 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Tue, 21 Jan 2025 13:45:38 +0000 Subject: [PATCH 8/9] Add test case with spaces before and after the arguments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Burak Varlı --- pkg/mountpoint/args_test.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/pkg/mountpoint/args_test.go b/pkg/mountpoint/args_test.go index c93d1491..d22318d1 100644 --- a/pkg/mountpoint/args_test.go +++ b/pkg/mountpoint/args_test.go @@ -80,6 +80,19 @@ func TestParsingMountpointArgs(t *testing.T) { "--uid=1000", }, }, + { + name: "with spaces before and after", + input: []string{ + "--allow-other ", + " --uid 1000", + " --gid 2000 ", + }, + want: []string{ + "--allow-other", + "--gid=2000", + "--uid=1000", + }, + }, { name: "with single dash prefix", input: []string{ From ecf40036b41cc8986b2eea22a80b345ea58936d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Tue, 21 Jan 2025 13:54:55 +0000 Subject: [PATCH 9/9] Make `mountpoint.Arg` struct private MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Burak Varlı --- pkg/mountpoint/args.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pkg/mountpoint/args.go b/pkg/mountpoint/args.go index ecfd5b5e..4825484b 100644 --- a/pkg/mountpoint/args.go +++ b/pkg/mountpoint/args.go @@ -28,14 +28,14 @@ type ArgValue = string // A value to use in arguments without any value, i.e., an option. const ArgNoValue = "" -// An Arg represents an argument to be passed to Mountpoint. -type Arg struct { +// An arg represents an argument to be passed to Mountpoint. +type arg struct { key ArgKey value ArgValue } // String returns string representation of the argument to pass Mountpoint. -func (a *Arg) String() string { +func (a *arg) String() string { if a.value == ArgNoValue { return a.key } @@ -44,17 +44,17 @@ func (a *Arg) String() string { // An Args represents arguments to be passed to Mountpoint during mount. type Args struct { - args sets.Set[Arg] + args sets.Set[arg] } // ParseArgs parses given list of unnormalized and returns a normalized [Args]. func ParseArgs(passedArgs []string) Args { - args := sets.New[Arg]() + args := sets.New[arg]() - for _, arg := range passedArgs { + for _, a := range passedArgs { var key, value string - parts := strings.SplitN(strings.Trim(arg, " "), "=", 2) + parts := strings.SplitN(strings.Trim(a, " "), "=", 2) if len(parts) == 2 { // Ex: `--key=value` or `key=value` key, value = parts[0], parts[1] @@ -81,7 +81,7 @@ func ParseArgs(passedArgs []string) Args { continue } - args.Insert(Arg{key, value}) + args.Insert(arg{key, value}) } return Args{args} @@ -91,7 +91,7 @@ func ParseArgs(passedArgs []string) Args { func (a *Args) Set(key ArgKey, value ArgValue) { key = normalizeKey(key) a.Remove(key) - a.args.Insert(Arg{key, value}) + a.args.Insert(arg{key, value}) } // Value extracts value of given key, it returns extracted value and whether the key was found. @@ -126,14 +126,14 @@ func (a *Args) SortedList() []string { } // find tries to find given key from [Args], and returns whole entry, and whether the key was found. -func (a *Args) find(key ArgKey) (Arg, bool) { +func (a *Args) find(key ArgKey) (arg, bool) { key = normalizeKey(key) for _, arg := range a.args.UnsortedList() { if key == arg.key { return arg, true } } - return Arg{}, false + return arg{}, false } // normalizeKey normalized given key to have a "--" prefix.