Skip to content

Commit

Permalink
Create mountpoint.Args for parsing and accessing Mountpoint args (#349
Browse files Browse the repository at this point in the history
)

Previously, accessing and manipulating of Mountpoint arguments was
spread over to the codebase. This new mountpoint package contains all
the logic related to Mountpoint arguments.

Splitted out of
#328.

---

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Signed-off-by: Burak Varlı <[email protected]>
  • Loading branch information
unexge authored Jan 22, 2025
1 parent 27337ea commit ef1c705
Show file tree
Hide file tree
Showing 13 changed files with 625 additions and 146 deletions.
13 changes: 5 additions & 8 deletions cmd/aws-s3-csi-mounter/csimounter/csimounter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import (
"io/fs"
"os"
"os/exec"
"slices"

"k8s.io/klog/v2"

"github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint"
"github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mountoptions"
"github.com/awslabs/aws-s3-csi-driver/pkg/podmounter/mppod"
)
Expand Down Expand Up @@ -52,20 +52,17 @@ func Run(options Options) (int, error) {
return 0, fmt.Errorf("passed file descriptor %d is invalid", mountOptions.Fd)
}

args := mountOptions.Args
mountpointArgs := mountpoint.ParseArgs(mountOptions.Args)

// By default Mountpoint runs in a detached mode. Here we want to monitor it by relaying its output,
// and also we want to wait until it terminates. We're passing `--foreground` to achieve this.
const foreground, foregroundShort = "--foreground", "-f"
if !(slices.Contains(args, foreground) || slices.Contains(args, foregroundShort)) {
args = append(args, foreground)
}
mountpointArgs.Set(mountpoint.ArgForeground, mountpoint.ArgNoValue)

args = append([]string{
args := append([]string{
mountOptions.BucketName,
// We pass FUSE fd using `ExtraFiles`, and each entry becomes as file descriptor 3+i.
"/dev/fd/3",
}, args...)
}, mountpointArgs.SortedList()...)

cmd := exec.Command(options.MountpointPath, args...)
cmd.ExtraFiles = []*os.File{fuseDev}
Expand Down
13 changes: 7 additions & 6 deletions pkg/driver/node/mounter/credential_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
k8sstrings "k8s.io/utils/strings"

"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext"
"github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint"
)

const hostPluginDirEnv = "HOST_PLUGIN_DIR"
Expand Down Expand Up @@ -84,15 +85,15 @@ func (c *CredentialProvider) CleanupToken(volumeID string, podID string) error {

// Provide provides mount credentials for given volume and volume context.
// Depending on the configuration, it either returns driver-level or pod-level credentials.
func (c *CredentialProvider) Provide(ctx context.Context, volumeID string, volumeCtx map[string]string, mountpointArgs []string) (*MountCredentials, error) {
func (c *CredentialProvider) Provide(ctx context.Context, volumeID string, volumeCtx map[string]string, args mountpoint.Args) (*MountCredentials, error) {
if volumeCtx == nil {
return nil, status.Error(codes.InvalidArgument, "Missing volume context")
}

authenticationSource := volumeCtx[volumecontext.AuthenticationSource]
switch authenticationSource {
case AuthenticationSourcePod:
return c.provideFromPod(ctx, volumeID, volumeCtx, mountpointArgs)
return c.provideFromPod(ctx, volumeID, volumeCtx, args)
case AuthenticationSourceUnspecified, AuthenticationSourceDriver:
return c.provideFromDriver()
default:
Expand All @@ -119,7 +120,7 @@ func (c *CredentialProvider) provideFromDriver() (*MountCredentials, error) {
}, nil
}

func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string, volumeCtx map[string]string, mountpointArgs []string) (*MountCredentials, error) {
func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string, volumeCtx map[string]string, args mountpoint.Args) (*MountCredentials, error) {
klog.V(4).Infof("NodePublishVolume: Using pod identity")

tokensJson := volumeCtx[volumecontext.CSIServiceAccountTokens]
Expand All @@ -144,7 +145,7 @@ func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string
return nil, err
}

region, err := c.stsRegion(volumeCtx, mountpointArgs)
region, err := c.stsRegion(volumeCtx, args)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "Failed to detect STS AWS Region, please explicitly set the AWS Region, see "+stsConfigDocsPage)
}
Expand Down Expand Up @@ -245,14 +246,14 @@ func (c *CredentialProvider) findPodServiceAccountRole(ctx context.Context, volu
// 4. Calling IMDS to detect region
//
// It returns an error if all of them fails.
func (c *CredentialProvider) stsRegion(volumeCtx map[string]string, mountpointArgs []string) (string, error) {
func (c *CredentialProvider) stsRegion(volumeCtx map[string]string, args mountpoint.Args) (string, error) {
region := volumeCtx[volumecontext.STSRegion]
if region != "" {
klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from volume context", region)
return region, nil
}

if region, ok := ExtractMountpointArgument(mountpointArgs, mountpointArgRegion); ok {
if region, ok := args.Value(mountpoint.ArgRegion); ok {
klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from S3 bucket region", region)
return region, nil
}
Expand Down
35 changes: 18 additions & 17 deletions pkg/driver/node/mounter/credential_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter"
"github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint"

v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -42,7 +43,7 @@ func TestProvidingDriverLevelCredentials(t *testing.T) {
} {

provider := mounter.NewCredentialProvider(nil, "", mounter.RegionFromIMDSOnce)
credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, nil)
credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, mountpoint.ParseArgs(nil))
assertEquals(t, nil, err)

assertEquals(t, credentials.AccessKeyID, "test-access-key")
Expand All @@ -58,7 +59,7 @@ func TestProvidingDriverLevelCredentials(t *testing.T) {

func TestProvidingDriverLevelCredentialsWithEmptyEnv(t *testing.T) {
provider := mounter.NewCredentialProvider(nil, "", mounter.RegionFromIMDSOnce)
credentials, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{"authenticationSource": "driver"}, nil)
credentials, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{"authenticationSource": "driver"}, mountpoint.ParseArgs(nil))
assertEquals(t, nil, err)

assertEquals(t, credentials.AccessKeyID, "")
Expand Down Expand Up @@ -93,7 +94,7 @@ func TestProvidingPodLevelCredentials(t *testing.T) {
Token: "test-service-account-token",
},
}),
}, nil)
}, mountpoint.ParseArgs(nil))
assertEquals(t, nil, err)

// Should disable env variable provider
Expand Down Expand Up @@ -216,7 +217,7 @@ func TestProvidingPodLevelCredentialsWithMissingInformation(t *testing.T) {
},
} {
t.Run(name, func(t *testing.T) {
credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, nil)
credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, mountpoint.ParseArgs(nil))
assertEquals(t, nil, credentials)
if err == nil {
t.Error("it should fail with missing information")
Expand Down Expand Up @@ -252,7 +253,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) {
return "", errors.New("unknown region")
})

credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil)
credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil))
assertEquals(t, nil, credentials)
if err == nil {
t.Error("it should fail if there is not any region information")
Expand All @@ -268,7 +269,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) {
return "us-east-1", nil
})

credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil)
credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil))
assertEquals(t, nil, err)
assertEquals(t, credentials.Region, "us-east-1")
assertEquals(t, credentials.DefaultRegion, "us-east-1")
Expand All @@ -286,7 +287,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) {

t.Setenv("AWS_REGION", "eu-west-1")

credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil)
credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil))
assertEquals(t, nil, err)
assertEquals(t, credentials.Region, "eu-west-1")
assertEquals(t, credentials.DefaultRegion, "eu-west-1")
Expand All @@ -304,7 +305,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) {

t.Setenv("AWS_DEFAULT_REGION", "eu-west-1")

credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil)
credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil))
assertEquals(t, nil, err)
assertEquals(t, credentials.Region, "eu-west-1")
assertEquals(t, credentials.DefaultRegion, "eu-west-1")
Expand All @@ -323,7 +324,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) {
t.Setenv("AWS_REGION", "eu-west-1")
t.Setenv("AWS_DEFAULT_REGION", "eu-north-1")

credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil)
credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs(nil))
assertEquals(t, nil, err)
assertEquals(t, credentials.Region, "eu-west-1")
assertEquals(t, credentials.DefaultRegion, "eu-north-1")
Expand All @@ -341,7 +342,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) {

t.Setenv("AWS_REGION", "eu-west-1")

credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"})
credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"}))
assertEquals(t, nil, err)
assertEquals(t, credentials.Region, "us-west-1")
assertEquals(t, credentials.DefaultRegion, "us-west-1")
Expand All @@ -359,7 +360,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) {

t.Setenv("AWS_REGION", "eu-west-1")

credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--read-only"})
credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--read-only"}))
assertEquals(t, nil, err)
assertEquals(t, credentials.Region, "eu-west-1")
assertEquals(t, credentials.DefaultRegion, "eu-west-1")
Expand All @@ -378,7 +379,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) {
t.Setenv("AWS_REGION", "eu-west-1")
t.Setenv("AWS_DEFAULT_REGION", "eu-north-1")

credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"})
credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"}))
assertEquals(t, nil, err)
assertEquals(t, credentials.Region, "us-west-1")
assertEquals(t, credentials.DefaultRegion, "eu-north-1")
Expand All @@ -398,7 +399,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) {

volumeContext["stsRegion"] = "ap-south-1"

credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"})
credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"}))
assertEquals(t, nil, err)
assertEquals(t, credentials.Region, "ap-south-1")
assertEquals(t, credentials.DefaultRegion, "ap-south-1")
Expand All @@ -419,7 +420,7 @@ func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) {

volumeContext["stsRegion"] = "ap-south-1"

credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"})
credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, mountpoint.ParseArgs([]string{"--region=us-west-1"}))
assertEquals(t, nil, err)
assertEquals(t, credentials.Region, "ap-south-1")
assertEquals(t, credentials.DefaultRegion, "eu-north-1")
Expand Down Expand Up @@ -457,7 +458,7 @@ func TestProvidingPodLevelCredentialsForDifferentPodsWithDifferentRoles(t *testi
Token: "test-service-account-token-1",
},
}),
}, nil)
}, mountpoint.ParseArgs(nil))
assertEquals(t, nil, err)

credentialsPodTwo, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{
Expand All @@ -470,7 +471,7 @@ func TestProvidingPodLevelCredentialsForDifferentPodsWithDifferentRoles(t *testi
Token: "test-service-account-token-2",
},
}),
}, nil)
}, mountpoint.ParseArgs(nil))
assertEquals(t, nil, err)

// PodOne
Expand Down Expand Up @@ -526,7 +527,7 @@ func TestProvidingPodLevelCredentialsWithSlashInVolumeID(t *testing.T) {
Token: "test-service-account-token",
},
}),
}, nil)
}, mountpoint.ParseArgs(nil))
assertEquals(t, nil, err)

assertEquals(t, credentials.AccessKeyID, "")
Expand Down
4 changes: 3 additions & 1 deletion pkg/driver/node/mounter/fake_mounter.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package mounter

import "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint"

type FakeMounter struct{}

func (m *FakeMounter) Mount(bucketName string, target string,
credentials *MountCredentials, options []string) error {
credentials *MountCredentials, args mountpoint.Args) error {
return nil
}

Expand Down
9 changes: 5 additions & 4 deletions pkg/driver/node/mounter/mocks/mock_mount.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pkg/driver/node/mounter/mount_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ const (
MountpointCacheKey = "UNSTABLE_MOUNTPOINT_CACHE_KEY"
defaultMountS3Path = "/usr/bin/mount-s3"
userAgentPrefix = "--user-agent-prefix"
awsMaxAttemptsOption = "--aws-max-attempts"
)

type MountCredentials struct {
Expand Down
3 changes: 2 additions & 1 deletion pkg/driver/node/mounter/mounter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"os"

"github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint"
"github.com/awslabs/aws-s3-csi-driver/pkg/system"
)

Expand All @@ -15,7 +16,7 @@ type ServiceRunner interface {

// Mounter is an interface for mount operations
type Mounter interface {
Mount(bucketName string, target string, credentials *MountCredentials, options []string) error
Mount(bucketName string, target string, credentials *MountCredentials, args mountpoint.Args) error
Unmount(target string) error
IsMountPoint(target string) (bool, error)
}
Expand Down
Loading

0 comments on commit ef1c705

Please sign in to comment.