diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 86654b9e..b9a09694 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -21,22 +21,22 @@ import ( "fmt" "net" "os" - "time" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" - "github.com/awslabs/aws-s3-csi-driver/pkg/util" "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/klog/v2" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/regionprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/version" ) const ( - driverName = "s3.csi.aws.com" - webIdentityTokenEnv = "AWS_WEB_IDENTITY_TOKEN_FILE" + driverName = "s3.csi.aws.com" grpcServerMaxReceiveMessageSize = 1024 * 1024 * 2 // 2MB @@ -74,13 +74,14 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error klog.Infof("Driver version: %v, Git commit: %v, build date: %v, nodeID: %v, mount-s3 version: %v, kubernetes version: %v", version.DriverVersion, version.GitCommit, version.BuildDate, nodeID, mpVersion, kubernetesVersion) - systemd_mounter, err := mounter.NewSystemdMounter(mpVersion, kubernetesVersion) + systemdMounter, err := mounter.NewSystemdMounter(mpVersion) if err != nil { klog.Fatalln(err) } - credentialProvider := mounter.NewCredentialProvider(clientset.CoreV1(), containerPluginDir, mounter.RegionFromIMDSOnce) - nodeServer := node.NewS3NodeServer(nodeID, systemd_mounter, credentialProvider) + credentialProvider := credentialprovider.New(clientset.CoreV1()) + regionProvider := regionprovider.New(regionprovider.RegionFromIMDSOnce) + nodeServer := node.NewS3NodeServer(nodeID, systemdMounter, credentialProvider, regionProvider, kubernetesVersion) return &Driver{ Endpoint: endpoint, @@ -90,14 +91,6 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error } func (d *Driver) Run() error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - tokenFile := os.Getenv(webIdentityTokenEnv) - if tokenFile != "" { - klog.Infof("Found AWS_WEB_IDENTITY_TOKEN_FILE, syncing token") - go tokenFileTender(ctx, tokenFile, "/csi/token") - } - scheme, addr, err := ParseEndpoint(d.Endpoint) if err != nil { return err @@ -150,22 +143,6 @@ func (d *Driver) Stop() { d.Srv.Stop() } -func tokenFileTender(ctx context.Context, sourcePath string, destPath string) { - for { - timer := time.After(10 * time.Second) - err := util.ReplaceFile(destPath, sourcePath, 0600) - if err != nil { - klog.Infof("Failed to sync AWS web token file: %v", err) - } - select { - case <-timer: - continue - case <-ctx.Done(): - return - } - } -} - func kubernetesVersion(clientset *kubernetes.Clientset) (string, error) { version, err := clientset.ServerVersion() if err != nil { diff --git a/pkg/driver/node/awsprofile/aws_profile_test.go b/pkg/driver/node/awsprofile/aws_profile_test.go deleted file mode 100644 index d64a9f0c..00000000 --- a/pkg/driver/node/awsprofile/aws_profile_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package awsprofile_test - -import ( - "context" - "errors" - "io/fs" - "os" - "testing" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/awsprofile" -) - -const testAccessKeyId = "test-access-key-id" -const testSecretAccessKey = "test-secret-access-key" -const testSessionToken = "test-session-token" - -func TestCreatingAWSProfile(t *testing.T) { - t.Run("create config and credentials files", func(t *testing.T) { - profile, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId, testSecretAccessKey, testSessionToken) - assertNoError(t, err) - assertCredentialsFromAWSProfile(t, profile, testAccessKeyId, testSecretAccessKey, testSessionToken) - }) - - t.Run("create config and credentials files with empty session token", func(t *testing.T) { - profile, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId, testSecretAccessKey, "") - assertNoError(t, err) - assertCredentialsFromAWSProfile(t, profile, testAccessKeyId, testSecretAccessKey, "") - }) - - t.Run("ensure config and credentials files are owner readable only", func(t *testing.T) { - profile, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId, testSecretAccessKey, testSessionToken) - assertNoError(t, err) - assertCredentialsFromAWSProfile(t, profile, testAccessKeyId, testSecretAccessKey, testSessionToken) - - configStat, err := os.Stat(profile.ConfigPath) - assertNoError(t, err) - assertEquals(t, 0400, configStat.Mode()) - - credentialsStat, err := os.Stat(profile.CredentialsPath) - assertNoError(t, err) - assertEquals(t, 0400, credentialsStat.Mode()) - }) - - t.Run("fail if credentials contains non-ascii characters", func(t *testing.T) { - t.Run("access key ID", func(t *testing.T) { - _, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId+"\n\t\r credential_process=exit", testSecretAccessKey, testSessionToken) - assertEquals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials)) - }) - t.Run("secret access key", func(t *testing.T) { - _, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId, testSecretAccessKey+"\n", testSessionToken) - assertEquals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials)) - }) - t.Run("session token", func(t *testing.T) { - _, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId, testSecretAccessKey, testSessionToken+"\n\r") - assertEquals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials)) - }) - }) -} - -func TestCleaningUpAWSProfile(t *testing.T) { - t.Run("clean config and credentials files", func(t *testing.T) { - basepath := t.TempDir() - - profile, err := awsprofile.CreateAWSProfile(basepath, testAccessKeyId, testSecretAccessKey, testSessionToken) - assertNoError(t, err) - assertCredentialsFromAWSProfile(t, profile, testAccessKeyId, testSecretAccessKey, testSessionToken) - - err = awsprofile.CleanupAWSProfile(basepath) - assertNoError(t, err) - - _, err = os.Stat(profile.ConfigPath) - assertEquals(t, true, errors.Is(err, fs.ErrNotExist)) - - _, err = os.Stat(profile.CredentialsPath) - assertEquals(t, true, errors.Is(err, fs.ErrNotExist)) - }) - - t.Run("cleaning non-existent config and credentials files should not be an error", func(t *testing.T) { - err := awsprofile.CleanupAWSProfile(t.TempDir()) - assertNoError(t, err) - }) -} - -func assertCredentialsFromAWSProfile(t *testing.T, profile awsprofile.AWSProfile, accessKeyID string, secretAccessKey string, sessionToken string) { - credentials := parseAWSProfile(t, profile) - assertEquals(t, accessKeyID, credentials.AccessKeyID) - assertEquals(t, secretAccessKey, credentials.SecretAccessKey) - assertEquals(t, sessionToken, credentials.SessionToken) -} - -func parseAWSProfile(t *testing.T, profile awsprofile.AWSProfile) aws.Credentials { - sharedConfig, err := config.LoadSharedConfigProfile(context.Background(), profile.Name, func(c *config.LoadSharedConfigOptions) { - c.ConfigFiles = []string{profile.ConfigPath} - c.CredentialsFiles = []string{profile.CredentialsPath} - }) - assertNoError(t, err) - return sharedConfig.Credentials -} - -func assertEquals[T comparable](t *testing.T, expected T, got T) { - if expected != got { - t.Errorf("Expected %#v, Got %#v", expected, got) - } -} - -func assertNoError(t *testing.T, err error) { - if err != nil { - t.Errorf("Expected no error, but got: %s", err) - } -} diff --git a/pkg/driver/node/awsprofile/aws_profile.go b/pkg/driver/node/credentialprovider/awsprofile/aws_profile.go similarity index 81% rename from pkg/driver/node/awsprofile/aws_profile.go rename to pkg/driver/node/credentialprovider/awsprofile/aws_profile.go index 8ecd4f78..2e48993b 100644 --- a/pkg/driver/node/awsprofile/aws_profile.go +++ b/pkg/driver/node/credentialprovider/awsprofile/aws_profile.go @@ -9,13 +9,14 @@ import ( "path/filepath" "strings" "unicode" + + "github.com/google/renameio" ) const ( awsProfileName = "s3-csi" awsProfileConfigFilename = "s3-csi-config" awsProfileCredentialsFilename = "s3-csi-credentials" - awsProfileFilePerm = fs.FileMode(0400) // only owner readable ) // ErrInvalidCredentials is returned when given AWS Credentials contains invalid characters. @@ -23,14 +24,14 @@ var ErrInvalidCredentials = errors.New("aws-profile: Invalid AWS Credentials") // An AWSProfile represents an AWS profile with it's credentials and config files. type AWSProfile struct { - Name string - ConfigPath string - CredentialsPath string + Name string + ConfigFilename string + CredentialsFilename string } // CreateAWSProfile creates an AWS Profile with credentials and config files from given credentials. // Created credentials and config files can be clean up with `CleanupAWSProfile`. -func CreateAWSProfile(basepath string, accessKeyID string, secretAccessKey string, sessionToken string) (AWSProfile, error) { +func CreateAWSProfile(basepath string, accessKeyID string, secretAccessKey string, sessionToken string, filePerm fs.FileMode) (AWSProfile, error) { if !isValidCredential(accessKeyID) || !isValidCredential(secretAccessKey) || !isValidCredential(sessionToken) { return AWSProfile{}, ErrInvalidCredentials } @@ -38,21 +39,21 @@ func CreateAWSProfile(basepath string, accessKeyID string, secretAccessKey strin name := awsProfileName configPath := filepath.Join(basepath, awsProfileConfigFilename) - err := writeAWSProfileFile(configPath, configFileContents(name)) + err := writeAWSProfileFile(configPath, configFileContents(name), filePerm) if err != nil { return AWSProfile{}, fmt.Errorf("aws-profile: Failed to create config file %s: %v", configPath, err) } credentialsPath := filepath.Join(basepath, awsProfileCredentialsFilename) - err = writeAWSProfileFile(credentialsPath, credentialsFileContents(name, accessKeyID, secretAccessKey, sessionToken)) + err = writeAWSProfileFile(credentialsPath, credentialsFileContents(name, accessKeyID, secretAccessKey, sessionToken), filePerm) if err != nil { return AWSProfile{}, fmt.Errorf("aws-profile: Failed to create credentials file %s: %v", credentialsPath, err) } return AWSProfile{ - Name: name, - ConfigPath: configPath, - CredentialsPath: credentialsPath, + Name: name, + ConfigFilename: awsProfileConfigFilename, + CredentialsFilename: awsProfileCredentialsFilename, }, nil } @@ -75,14 +76,8 @@ func CleanupAWSProfile(basepath string) error { return nil } -func writeAWSProfileFile(path string, content string) error { - err := os.WriteFile(path, []byte(content), awsProfileFilePerm) - if err != nil { - return err - } - // If the given file exists, `os.WriteFile` just truncates it without changing it's permissions, - // so we need to ensure it always has the correct permissions. - return os.Chmod(path, awsProfileFilePerm) +func writeAWSProfileFile(path string, content string, filePerm os.FileMode) error { + return renameio.WriteFile(path, []byte(content), filePerm) } func credentialsFileContents(profile string, accessKeyID string, secretAccessKey string, sessionToken string) string { diff --git a/pkg/driver/node/credentialprovider/awsprofile/aws_profile_test.go b/pkg/driver/node/credentialprovider/awsprofile/aws_profile_test.go new file mode 100644 index 00000000..ca54ff29 --- /dev/null +++ b/pkg/driver/node/credentialprovider/awsprofile/aws_profile_test.go @@ -0,0 +1,100 @@ +package awsprofile_test + +import ( + "errors" + "io/fs" + "os" + "path/filepath" + "testing" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider/awsprofile" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider/awsprofile/awsprofiletest" + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" +) + +const testAccessKeyId = "test-access-key-id" +const testSecretAccessKey = "test-secret-access-key" +const testSessionToken = "test-session-token" +const testFilePerm = fs.FileMode(0600) + +func TestCreatingAWSProfile(t *testing.T) { + t.Run("create config and credentials files", func(t *testing.T) { + basepath := t.TempDir() + profile, err := awsprofile.CreateAWSProfile(basepath, testAccessKeyId, testSecretAccessKey, testSessionToken, testFilePerm) + assert.NoError(t, err) + assertCredentialsFromAWSProfile(t, basepath, profile, testAccessKeyId, testSecretAccessKey, testSessionToken) + }) + + t.Run("create config and credentials files with empty session token", func(t *testing.T) { + basepath := t.TempDir() + profile, err := awsprofile.CreateAWSProfile(basepath, testAccessKeyId, testSecretAccessKey, "", testFilePerm) + assert.NoError(t, err) + assertCredentialsFromAWSProfile(t, basepath, profile, testAccessKeyId, testSecretAccessKey, "") + }) + + t.Run("ensure config and credentials files are created with correct permissions", func(t *testing.T) { + basepath := t.TempDir() + profile, err := awsprofile.CreateAWSProfile(basepath, testAccessKeyId, testSecretAccessKey, testSessionToken, testFilePerm) + assert.NoError(t, err) + assertCredentialsFromAWSProfile(t, basepath, profile, testAccessKeyId, testSecretAccessKey, testSessionToken) + + configStat, err := os.Stat(filepath.Join(basepath, profile.ConfigFilename)) + assert.NoError(t, err) + assert.Equals(t, testFilePerm, configStat.Mode()) + + credentialsStat, err := os.Stat(filepath.Join(basepath, profile.CredentialsFilename)) + assert.NoError(t, err) + assert.Equals(t, testFilePerm, credentialsStat.Mode()) + }) + + t.Run("fail if credentials contains non-ascii characters", func(t *testing.T) { + t.Run("access key ID", func(t *testing.T) { + _, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId+"\n\t\r credential_process=exit", testSecretAccessKey, testSessionToken, testFilePerm) + assert.Equals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials)) + }) + t.Run("secret access key", func(t *testing.T) { + _, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId, testSecretAccessKey+"\n", testSessionToken, testFilePerm) + assert.Equals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials)) + }) + t.Run("session token", func(t *testing.T) { + _, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId, testSecretAccessKey, testSessionToken+"\n\r", testFilePerm) + assert.Equals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials)) + }) + }) +} + +func TestCleaningUpAWSProfile(t *testing.T) { + t.Run("clean config and credentials files", func(t *testing.T) { + basepath := t.TempDir() + + profile, err := awsprofile.CreateAWSProfile(basepath, testAccessKeyId, testSecretAccessKey, testSessionToken, testFilePerm) + assert.NoError(t, err) + assertCredentialsFromAWSProfile(t, basepath, profile, testAccessKeyId, testSecretAccessKey, testSessionToken) + + err = awsprofile.CleanupAWSProfile(basepath) + assert.NoError(t, err) + + _, err = os.Stat(filepath.Join(basepath, profile.ConfigFilename)) + assert.Equals(t, true, errors.Is(err, fs.ErrNotExist)) + + _, err = os.Stat(filepath.Join(basepath, profile.CredentialsFilename)) + assert.Equals(t, true, errors.Is(err, fs.ErrNotExist)) + }) + + t.Run("cleaning non-existent config and credentials files should not be an error", func(t *testing.T) { + err := awsprofile.CleanupAWSProfile(t.TempDir()) + assert.NoError(t, err) + }) +} + +func assertCredentialsFromAWSProfile(t *testing.T, basepath string, profile awsprofile.AWSProfile, accessKeyID string, secretAccessKey string, sessionToken string) { + awsprofiletest.AssertCredentialsFromAWSProfile( + t, + profile.Name, + filepath.Join(basepath, profile.ConfigFilename), + filepath.Join(basepath, profile.CredentialsFilename), + accessKeyID, + secretAccessKey, + sessionToken, + ) +} diff --git a/pkg/driver/node/credentialprovider/awsprofile/awsprofiletest/aws_profile.go b/pkg/driver/node/credentialprovider/awsprofile/awsprofiletest/aws_profile.go new file mode 100644 index 00000000..5670eb6b --- /dev/null +++ b/pkg/driver/node/credentialprovider/awsprofile/awsprofiletest/aws_profile.go @@ -0,0 +1,30 @@ +// Package awsprofiletest provides testing utilities for AWS Profiles. +package awsprofiletest + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" +) + +func AssertCredentialsFromAWSProfile(t *testing.T, profileName, configFile, credentialsFile, accessKeyID, secretAccessKey, sessionToken string) { + t.Helper() + + credentials := parseAWSProfile(t, profileName, configFile, credentialsFile) + assert.Equals(t, accessKeyID, credentials.AccessKeyID) + assert.Equals(t, secretAccessKey, credentials.SecretAccessKey) + assert.Equals(t, sessionToken, credentials.SessionToken) +} + +func parseAWSProfile(t *testing.T, profileName, configFile, credentialsFile string) aws.Credentials { + sharedConfig, err := config.LoadSharedConfigProfile(context.Background(), profileName, func(c *config.LoadSharedConfigOptions) { + c.ConfigFiles = []string{configFile} + c.CredentialsFiles = []string{credentialsFile} + }) + assert.NoError(t, err) + return sharedConfig.Credentials +} diff --git a/pkg/driver/node/credentialprovider/credentials.go b/pkg/driver/node/credentialprovider/credentials.go new file mode 100644 index 00000000..92b34d4b --- /dev/null +++ b/pkg/driver/node/credentialprovider/credentials.go @@ -0,0 +1,28 @@ +package credentialprovider + +import ( + "io/fs" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" +) + +// CredentialFilePerm is the default permissions to be used for credential files. +// It's only readable and writeable by the owner. +const CredentialFilePerm = fs.FileMode(0600) + +// CredentialDirPerm is the default permissions to be used for credential directories. +// It's only readable, listable (execute bit), and writeable by the owner. +const CredentialDirPerm = fs.FileMode(0700) + +// Credentials is the interface implemented by credential providers. +type Credentials interface { + // Source returns the source of these credentials. + Source() AuthenticationSource + + // Dump dumps credentials into `writePath` and returns environment variables + // relative to `envPath` to pass to Mountpoint during mount. + // + // The environment variables will only passed to Mountpoint once during mount operation, + // in subsequent calls, this method will update previously written credentials on disk. + Dump(writePath string, envPath string) (envprovider.Environment, error) +} diff --git a/pkg/driver/node/credentialprovider/credentials_long_term.go b/pkg/driver/node/credentialprovider/credentials_long_term.go new file mode 100644 index 00000000..8cdafe8a --- /dev/null +++ b/pkg/driver/node/credentialprovider/credentials_long_term.go @@ -0,0 +1,38 @@ +package credentialprovider + +import ( + "fmt" + "path/filepath" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider/awsprofile" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" +) + +type longTermCredentials struct { + source AuthenticationSource + + accessKeyID string + secretAccessKey string + sessionToken string +} + +func (c *longTermCredentials) Source() AuthenticationSource { + return c.source +} + +func (c *longTermCredentials) Dump(writePath string, envPath string) (envprovider.Environment, error) { + awsProfile, err := awsprofile.CreateAWSProfile(writePath, c.accessKeyID, c.secretAccessKey, c.sessionToken, CredentialFilePerm) + if err != nil { + return nil, fmt.Errorf("credentialprovider: long-term: failed to create aws profile: %w", err) + } + + profile := awsProfile.Name + configFile := filepath.Join(envPath, awsProfile.ConfigFilename) + credentialsFile := filepath.Join(envPath, awsProfile.CredentialsFilename) + + return envprovider.Environment{ + envprovider.Format(envprovider.EnvProfile, profile), + envprovider.Format(envprovider.EnvConfigFile, configFile), + envprovider.Format(envprovider.EnvSharedCredentialsFile, credentialsFile), + }, nil +} diff --git a/pkg/driver/node/credentialprovider/credentials_multi.go b/pkg/driver/node/credentialprovider/credentials_multi.go new file mode 100644 index 00000000..b89a280a --- /dev/null +++ b/pkg/driver/node/credentialprovider/credentials_multi.go @@ -0,0 +1,24 @@ +package credentialprovider + +import "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" + +type multiCredentials struct { + source AuthenticationSource + credentials []Credentials +} + +func (c *multiCredentials) Source() AuthenticationSource { + return c.source +} + +func (c *multiCredentials) Dump(writePath string, envPath string) (envprovider.Environment, error) { + environment := envprovider.Environment{} + for _, c := range c.credentials { + env, err := c.Dump(writePath, envPath) + if err != nil { + return nil, err + } + environment = append(environment, env...) + } + return environment, nil +} diff --git a/pkg/driver/node/credentialprovider/credentials_shared_profile.go b/pkg/driver/node/credentialprovider/credentials_shared_profile.go new file mode 100644 index 00000000..69fbac17 --- /dev/null +++ b/pkg/driver/node/credentialprovider/credentials_shared_profile.go @@ -0,0 +1,23 @@ +package credentialprovider + +import ( + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" +) + +type sharedProfileCredentials struct { + source AuthenticationSource + + configFile string + sharedCredentialsFile string +} + +func (c *sharedProfileCredentials) Source() AuthenticationSource { + return c.source +} + +func (c *sharedProfileCredentials) Dump(writePath string, envPath string) (envprovider.Environment, error) { + return envprovider.Environment{ + envprovider.Format(envprovider.EnvConfigFile, c.configFile), + envprovider.Format(envprovider.EnvSharedCredentialsFile, c.sharedCredentialsFile), + }, nil +} diff --git a/pkg/driver/node/credentialprovider/credentials_sts_web_identity.go b/pkg/driver/node/credentialprovider/credentials_sts_web_identity.go new file mode 100644 index 00000000..8b3cd9e8 --- /dev/null +++ b/pkg/driver/node/credentialprovider/credentials_sts_web_identity.go @@ -0,0 +1,60 @@ +package credentialprovider + +import ( + "path/filepath" + + "github.com/google/renameio" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/util" +) + +const ( + serviceAccountFilename = "serviceaccount.token" +) + +type stsWebIdentityCredentials struct { + source AuthenticationSource + + cacheKey string + + roleARN string + + // These are mutually exclusive, if both set, `webIdentityToken` will be used. + webIdentityToken string + webIdentityTokenFile string +} + +func (c *stsWebIdentityCredentials) Source() AuthenticationSource { + return c.source +} + +func (c *stsWebIdentityCredentials) Dump(writePath string, envPath string) (envprovider.Environment, error) { + env := envprovider.Environment{ + envprovider.Format(envprovider.EnvRoleARN, c.roleARN), + envprovider.Format(envprovider.EnvWebIdentityTokenFile, filepath.Join(envPath, serviceAccountFilename)), + } + + tokenPath := filepath.Join(writePath, serviceAccountFilename) + + var err error + if c.webIdentityToken != "" { + err = renameio.WriteFile(tokenPath, []byte(c.webIdentityToken), CredentialFilePerm) + } else { + err = util.ReplaceFile(tokenPath, c.webIdentityTokenFile, CredentialFilePerm) + } + if err != nil { + return nil, err + } + + // TODO: These were needed with `systemd` but probably won't be necessary with containerization - except disabling IMDS provider probably. + if c.source == AuthenticationSourcePod { + env = append(env, + envprovider.Format(envprovider.EnvMountpointCacheKey, c.cacheKey), + envprovider.Format(envprovider.EnvConfigFile, filepath.Join(envPath, "disable-config")), + envprovider.Format(envprovider.EnvSharedCredentialsFile, filepath.Join(envPath, "disable-credentials")), + envprovider.Format(envprovider.EnvEC2MetadataDisabled, "true")) + } + + return env, nil +} diff --git a/pkg/driver/node/credentialprovider/provider.go b/pkg/driver/node/credentialprovider/provider.go new file mode 100644 index 00000000..b7795bfc --- /dev/null +++ b/pkg/driver/node/credentialprovider/provider.go @@ -0,0 +1,190 @@ +// Package credentialprovider provides utilities for obtaining AWS credentials to use. +// Depending on the configuration, it either uses Pod-level or Driver-level credentials. +package credentialprovider + +import ( + "context" + "encoding/json" + "fmt" + "os" + "time" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8sv1 "k8s.io/client-go/kubernetes/typed/core/v1" + "k8s.io/klog/v2" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" +) + +// An AuthenticationSource represents the source where the credentials was obtained. +type AuthenticationSource = string + +const ( + // This is when users don't provide a `authenticationSource` option in their volume attributes. + // We're defaulting to `driver` in this case. + AuthenticationSourceUnspecified AuthenticationSource = "" + AuthenticationSourceDriver AuthenticationSource = "driver" + AuthenticationSourcePod AuthenticationSource = "pod" +) + +const ( + envAccessKeyID = "AWS_ACCESS_KEY_ID" + envSecretAccessKey = "AWS_SECRET_ACCESS_KEY" + envSessionToken = "AWS_SESSION_TOKEN" + envConfigFile = "AWS_CONFIG_FILE" + envSharedCredentialsFile = "AWS_SHARED_CREDENTIALS_FILE" + envRoleARN = "AWS_ROLE_ARN" + envWebIdentityTokenFile = "AWS_WEB_IDENTITY_TOKEN_FILE" +) + +const ( + serviceAccountTokenAudienceSTS = "sts.amazonaws.com" + + serviceAccountRoleAnnotation = "eks.amazonaws.com/role-arn" +) + +const podLevelCredentialsDocsPage = "https://github.com/awslabs/mountpoint-s3-csi-driver/blob/main/docs/CONFIGURATION.md#pod-level-credentials" + +type serviceAccountToken struct { + Token string `json:"token"` + ExpirationTimestamp time.Time `json:"expirationTimestamp"` +} + +// A Provider provides methods for accessing AWS credentials. +type Provider struct { + client k8sv1.CoreV1Interface +} + +// New creates a new [Provider] with given client. +func New(client k8sv1.CoreV1Interface) *Provider { + return &Provider{client} +} + +// Provide provides credentials for given volume context. +// Depending on the configuration, it either returns driver-level or pod-level credentials. +func (c *Provider) Provide(ctx context.Context, volumeContext map[string]string) (Credentials, error) { + if volumeContext == nil { + return nil, status.Error(codes.InvalidArgument, "Missing volume context") + } + + authenticationSource := volumeContext[volumecontext.AuthenticationSource] + switch authenticationSource { + case AuthenticationSourcePod: + return c.provideFromPod(ctx, volumeContext) + case AuthenticationSourceUnspecified, AuthenticationSourceDriver: + return c.provideFromDriver() + default: + return nil, fmt.Errorf("unknown `authenticationSource`: %s, only `driver` (default option if not specified) and `pod` supported", authenticationSource) + } +} + +// provideFromDriver provides driver-level AWS credentials. +func (c *Provider) provideFromDriver() (Credentials, error) { + klog.V(4).Infof("credentialprovider: Using driver identity") + + source := AuthenticationSourceDriver + var credentials []Credentials + + // Long-term AWS credentials + accessKeyID := os.Getenv(envAccessKeyID) + secretAccessKey := os.Getenv(envSecretAccessKey) + if accessKeyID != "" && secretAccessKey != "" { + sessionToken := os.Getenv(envSessionToken) + credentials = append(credentials, &longTermCredentials{ + source, + accessKeyID, + secretAccessKey, + sessionToken, + }) + } else { + // Profile provider + // TODO: This is not officially supported and won't work by default with containerization. Should we remove it? + configFile := os.Getenv(envConfigFile) + sharedCredentialsFile := os.Getenv(envSharedCredentialsFile) + if configFile != "" && sharedCredentialsFile != "" { + credentials = append(credentials, &sharedProfileCredentials{ + source, + configFile, + sharedCredentialsFile, + }) + } + } + + // STS Web Identity provider + webIdentityTokenFile := os.Getenv(envWebIdentityTokenFile) + roleARN := os.Getenv(envRoleARN) + if webIdentityTokenFile != "" && roleARN != "" { + credentials = append(credentials, &stsWebIdentityCredentials{ + source: source, + roleARN: roleARN, + webIdentityTokenFile: webIdentityTokenFile, + }) + } + + // Here we don't return an error even `credentials` are empty, because there might be Instance Profile Role + // configured and Mountpoint/CRT would fallback to that if we just return empty credentials/environment-variables. + return &multiCredentials{source: source, credentials: credentials}, nil +} + +// provideFromPod provides pod-level AWS credentials. +func (c *Provider) provideFromPod(ctx context.Context, volumeContext map[string]string) (Credentials, error) { + klog.V(4).Infof("credentialprovider: Using pod identity") + + tokensJson := volumeContext[volumecontext.CSIServiceAccountTokens] + if tokensJson == "" { + klog.Error("`authenticationSource` configured to `pod` but no service account tokens are received. Please make sure to enable `podInfoOnMountCompat`, see " + podLevelCredentialsDocsPage) + return nil, status.Error(codes.InvalidArgument, "Missing service account tokens. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage) + } + + var tokens map[string]*serviceAccountToken + if err := json.Unmarshal([]byte(tokensJson), &tokens); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "Failed to parse service account tokens: %v", err) + } + + stsToken := tokens[serviceAccountTokenAudienceSTS] + if stsToken == nil { + klog.Errorf("`authenticationSource` configured to `pod` but no service account tokens for %s received. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage, serviceAccountTokenAudienceSTS) + return nil, status.Errorf(codes.InvalidArgument, "Missing service account token for %s", serviceAccountTokenAudienceSTS) + } + + roleARN, err := c.findPodServiceAccountRole(ctx, volumeContext) + if err != nil { + return nil, err + } + + podNamespace := volumeContext[volumecontext.CSIPodNamespace] + podServiceAccount := volumeContext[volumecontext.CSIServiceAccountName] + cacheKey := podNamespace + "/" + podServiceAccount + + return &stsWebIdentityCredentials{ + source: AuthenticationSourcePod, + webIdentityToken: stsToken.Token, + roleARN: roleARN, + cacheKey: cacheKey, + }, nil +} + +// findPodServiceAccountRole tries to provide associated AWS IAM role for service account specified in the volume context. +func (c *Provider) findPodServiceAccountRole(ctx context.Context, volumeContext map[string]string) (string, error) { + podNamespace := volumeContext[volumecontext.CSIPodNamespace] + podServiceAccount := volumeContext[volumecontext.CSIServiceAccountName] + if podNamespace == "" || podServiceAccount == "" { + klog.Error("`authenticationSource` configured to `pod` but no pod info found. Please make sure to enable `podInfoOnMountCompat`, see " + podLevelCredentialsDocsPage) + return "", status.Error(codes.InvalidArgument, "Missing Pod info. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage) + } + + response, err := c.client.ServiceAccounts(podNamespace).Get(ctx, podServiceAccount, metav1.GetOptions{}) + if err != nil { + return "", status.Errorf(codes.InvalidArgument, "Failed to get pod's service account %s/%s: %v", podNamespace, podServiceAccount, err) + } + + roleArn := response.Annotations[serviceAccountRoleAnnotation] + if roleArn == "" { + klog.Error("`authenticationSource` configured to `pod` but pod's service account is not annotated with a role, see " + podLevelCredentialsDocsPage) + return "", status.Errorf(codes.InvalidArgument, "Missing role annotation on pod's service account %s/%s", podNamespace, podServiceAccount) + } + + return roleArn, nil +} diff --git a/pkg/driver/node/credentialprovider/provider_test.go b/pkg/driver/node/credentialprovider/provider_test.go new file mode 100644 index 00000000..f49ab7ca --- /dev/null +++ b/pkg/driver/node/credentialprovider/provider_test.go @@ -0,0 +1,302 @@ +package credentialprovider_test + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "testing" + "time" + + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider/awsprofile/awsprofiletest" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" +) + +const testAccessKeyID = "test-access-key-id" +const testSecretAccessKey = "test-secret-access-key" +const testSessionToken = "test-session-token" + +const testRoleARN = "arn:aws:iam::111122223333:role/pod-a-role" +const testWebIdentityToken = "test-web-identity-token" + +const testEnvPath = "/test-env" + +func TestProvidingDriverLevelCredentials(t *testing.T) { + volumeContextVariants := []map[string]string{ + { + volumecontext.AuthenticationSource: credentialprovider.AuthenticationSourceDriver, + }, + // It should default to driver-level identity if `authenticationSource` is not passed + { + volumecontext.AuthenticationSource: credentialprovider.AuthenticationSourceUnspecified, + }, + {}, + } + + t.Run("only long-term credentials", func(t *testing.T) { + for _, volCtx := range volumeContextVariants { + setEnvForLongTermCredentials(t) + writePath := t.TempDir() + + provider := credentialprovider.New(nil) + credentials, err := provider.Provide(context.Background(), volCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourceDriver, credentials.Source()) + + env, err := credentials.Dump(writePath, testEnvPath) + assert.NoError(t, err) + + assert.Equals(t, envprovider.Environment{ + "AWS_PROFILE=s3-csi", + "AWS_CONFIG_FILE=/test-env/s3-csi-config", + "AWS_SHARED_CREDENTIALS_FILE=/test-env/s3-csi-credentials", + }, env) + assertLongTermCredentials(t, writePath) + } + }) + + t.Run("only sts web identity credentials", func(t *testing.T) { + for _, volCtx := range volumeContextVariants { + setEnvForStsWebIdentityCredentials(t) + writePath := t.TempDir() + + provider := credentialprovider.New(nil) + credentials, err := provider.Provide(context.Background(), volCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourceDriver, credentials.Source()) + + env, err := credentials.Dump(writePath, testEnvPath) + assert.NoError(t, err) + + assert.Equals(t, envprovider.Environment{ + fmt.Sprintf("AWS_ROLE_ARN=%s", testRoleARN), + "AWS_WEB_IDENTITY_TOKEN_FILE=/test-env/serviceaccount.token", + }, env) + assertWebIdentityTokenFile(t, writePath) + } + }) + + t.Run("both long-term and sts web identity credentials", func(t *testing.T) { + for _, volCtx := range volumeContextVariants { + setEnvForLongTermCredentials(t) + setEnvForStsWebIdentityCredentials(t) + writePath := t.TempDir() + + provider := credentialprovider.New(nil) + credentials, err := provider.Provide(context.Background(), volCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourceDriver, credentials.Source()) + + env, err := credentials.Dump(writePath, testEnvPath) + assert.NoError(t, err) + + assert.Equals(t, envprovider.Environment{ + "AWS_PROFILE=s3-csi", + "AWS_CONFIG_FILE=/test-env/s3-csi-config", + "AWS_SHARED_CREDENTIALS_FILE=/test-env/s3-csi-credentials", + fmt.Sprintf("AWS_ROLE_ARN=%s", testRoleARN), + "AWS_WEB_IDENTITY_TOKEN_FILE=/test-env/serviceaccount.token", + }, env) + assertLongTermCredentials(t, writePath) + assertWebIdentityTokenFile(t, writePath) + } + }) + + t.Run("no credentials", func(t *testing.T) { + for _, volCtx := range volumeContextVariants { + provider := credentialprovider.New(nil) + credentials, err := provider.Provide(context.Background(), volCtx) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourceDriver, credentials.Source()) + + env, err := credentials.Dump(t.TempDir(), testEnvPath) + assert.NoError(t, err) + assert.Equals(t, envprovider.Environment{}, env) + } + }) +} + +func TestProvidingPodLevelCredentials(t *testing.T) { + t.Run("correct values", func(t *testing.T) { + clientset := fake.NewSimpleClientset(serviceAccount("test-sa", "test-ns", map[string]string{ + "eks.amazonaws.com/role-arn": testRoleARN, + })) + + provider := credentialprovider.New(clientset.CoreV1()) + credentials, err := provider.Provide(context.Background(), map[string]string{ + volumecontext.AuthenticationSource: credentialprovider.AuthenticationSourcePod, + volumecontext.CSIPodNamespace: "test-ns", + volumecontext.CSIServiceAccountName: "test-sa", + volumecontext.CSIServiceAccountTokens: serviceAccountTokens(t, tokens{ + "sts.amazonaws.com": { + Token: testWebIdentityToken, + }, + }), + }) + assert.NoError(t, err) + assert.Equals(t, credentialprovider.AuthenticationSourcePod, credentials.Source()) + + writePath := t.TempDir() + + env, err := credentials.Dump(writePath, testEnvPath) + assert.NoError(t, err) + + assert.Equals(t, envprovider.Environment{ + fmt.Sprintf("AWS_ROLE_ARN=%s", testRoleARN), + "AWS_WEB_IDENTITY_TOKEN_FILE=/test-env/serviceaccount.token", + + // Having a unique cache key for namespace/serviceaccount pair + "UNSTABLE_MOUNTPOINT_CACHE_KEY=test-ns/test-sa", + + // Disable long-term credentials + "AWS_CONFIG_FILE=/test-env/disable-config", + "AWS_SHARED_CREDENTIALS_FILE=/test-env/disable-credentials", + + // Disable EC2 credentials + "AWS_EC2_METADATA_DISABLED=true", + }, env) + assertWebIdentityTokenFile(t, writePath) + }) + + t.Run("missing information", func(t *testing.T) { + clientset := fake.NewSimpleClientset( + serviceAccount("test-sa", "test-ns", map[string]string{ + "eks.amazonaws.com/role-arn": testRoleARN, + }), + serviceAccount("test-sa-missing-role", "test-ns", map[string]string{}), + ) + + for name, test := range map[string]struct { + volumeContext map[string]string + }{ + "unknown service account": { + volumeContext: map[string]string{ + volumecontext.AuthenticationSource: credentialprovider.AuthenticationSourcePod, + volumecontext.CSIPodNamespace: "test-ns", + volumecontext.CSIServiceAccountName: "test-unknown-sa", + volumecontext.CSIServiceAccountTokens: serviceAccountTokens(t, tokens{ + "sts.amazonaws.com": { + Token: testWebIdentityToken, + }, + }), + }, + }, + "missing service account token": { + volumeContext: map[string]string{ + volumecontext.AuthenticationSource: credentialprovider.AuthenticationSourcePod, + volumecontext.CSIPodNamespace: "test-ns", + volumecontext.CSIServiceAccountName: "test-sa", + }, + }, + "missing sts audience in service account tokens": { + volumeContext: map[string]string{ + volumecontext.AuthenticationSource: credentialprovider.AuthenticationSourcePod, + volumecontext.CSIPodNamespace: "test-ns", + volumecontext.CSIServiceAccountName: "test-sa", + volumecontext.CSIServiceAccountTokens: serviceAccountTokens(t, tokens{ + "unknown": { + Token: testWebIdentityToken, + }, + }), + }, + }, + "missing service account name": { + volumeContext: map[string]string{ + volumecontext.AuthenticationSource: credentialprovider.AuthenticationSourcePod, + volumecontext.CSIPodNamespace: "test-ns", + volumecontext.CSIServiceAccountTokens: serviceAccountTokens(t, tokens{ + "sts.amazonaws.com": { + Token: testWebIdentityToken, + }, + }), + }, + }, + "missing pod namespace": { + volumeContext: map[string]string{ + volumecontext.AuthenticationSource: credentialprovider.AuthenticationSourcePod, + volumecontext.CSIServiceAccountName: "test-sa", + volumecontext.CSIServiceAccountTokens: serviceAccountTokens(t, tokens{ + "sts.amazonaws.com": { + Token: testWebIdentityToken, + }, + }), + }, + }, + } { + t.Run(name, func(t *testing.T) { + provider := credentialprovider.New(clientset.CoreV1()) + _, err := provider.Provide(context.Background(), test.volumeContext) + if err == nil { + t.Error("it should fail with missing information") + } + }) + } + }) +} + +//-- Utilities for tests + +func setEnvForLongTermCredentials(t *testing.T) { + t.Setenv("AWS_ACCESS_KEY_ID", testAccessKeyID) + t.Setenv("AWS_SECRET_ACCESS_KEY", testSecretAccessKey) + t.Setenv("AWS_SESSION_TOKEN", testSessionToken) +} + +func assertLongTermCredentials(t *testing.T, basepath string) { + t.Helper() + + awsprofiletest.AssertCredentialsFromAWSProfile( + t, + "s3-csi", + filepath.Join(basepath, "s3-csi-config"), + filepath.Join(basepath, "s3-csi-credentials"), + "test-access-key-id", + "test-secret-access-key", + "test-session-token", + ) +} + +func setEnvForStsWebIdentityCredentials(t *testing.T) { + t.Helper() + + tokenPath := filepath.Join(t.TempDir(), "token") + assert.NoError(t, os.WriteFile(tokenPath, []byte(testWebIdentityToken), 0600)) + + t.Setenv("AWS_ROLE_ARN", testRoleARN) + t.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", tokenPath) +} + +func assertWebIdentityTokenFile(t *testing.T, basepath string) { + t.Helper() + + got, err := os.ReadFile(filepath.Join(basepath, "serviceaccount.token")) + assert.NoError(t, err) + assert.Equals(t, []byte(testWebIdentityToken), got) +} + +type tokens = map[string]struct { + Token string `json:"token"` + ExpirationTimestamp time.Time +} + +func serviceAccountTokens(t *testing.T, tokens tokens) string { + buf, err := json.Marshal(&tokens) + assert.NoError(t, err) + return string(buf) +} + +func serviceAccount(name, namespace string, annotations map[string]string) *v1.ServiceAccount { + return &v1.ServiceAccount{ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + Annotations: annotations, + }} +} diff --git a/pkg/driver/node/envprovider/provider.go b/pkg/driver/node/envprovider/provider.go new file mode 100644 index 00000000..141f0991 --- /dev/null +++ b/pkg/driver/node/envprovider/provider.go @@ -0,0 +1,74 @@ +// Package envprovider provides utilities for accessing environment variables to pass Mountpoint. +package envprovider + +import ( + "fmt" + "os" + "slices" + "strings" +) + +const ( + EnvRegion = "AWS_REGION" + EnvDefaultRegion = "AWS_DEFAULT_REGION" + EnvSTSRegionalEndpoints = "AWS_STS_REGIONAL_ENDPOINTS" + EnvMaxAttempts = "AWS_MAX_ATTEMPTS" + EnvProfile = "AWS_PROFILE" + EnvConfigFile = "AWS_CONFIG_FILE" + EnvSharedCredentialsFile = "AWS_SHARED_CREDENTIALS_FILE" + EnvRoleARN = "AWS_ROLE_ARN" + EnvWebIdentityTokenFile = "AWS_WEB_IDENTITY_TOKEN_FILE" + EnvEC2MetadataDisabled = "AWS_EC2_METADATA_DISABLED" + + EnvMountpointCacheKey = "UNSTABLE_MOUNTPOINT_CACHE_KEY" +) + +// An Environment represents a list of environment variables. +type Environment = []string + +// envAllowlist is the list of environment variables to pass-by by default. +// If any of these set, it will be returned as-is in [Provide]. +var envAllowlist = []string{ + EnvRegion, + EnvDefaultRegion, + EnvSTSRegionalEndpoints, +} + +// Region returns detected region from environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`. +// It returns an empty string if both is unset. +func Region() string { + region := os.Getenv(EnvRegion) + if region != "" { + return region + } + return os.Getenv(EnvDefaultRegion) +} + +// Provide returns list of environment variables to pass Mountpoint. +func Provide() Environment { + environment := Environment{} + for _, key := range envAllowlist { + val := os.Getenv(key) + if val != "" { + environment = append(environment, Format(key, val)) + } + } + return environment +} + +// Format formats given key and value to be used as an environment variable. +func Format(key, value string) string { + return fmt.Sprintf("%s=%s", key, value) +} + +// Remove removes environment variable with given `key` from given environment variables `env`. +// It returns updated environment variables. +func Remove(env Environment, key string) Environment { + prefix := key + if !strings.HasSuffix(key, "=") { + prefix = prefix + "=" + } + return slices.DeleteFunc(env, func(k string) bool { + return strings.HasPrefix(k, prefix) + }) +} diff --git a/pkg/driver/node/envprovider/provider_test.go b/pkg/driver/node/envprovider/provider_test.go new file mode 100644 index 00000000..18534e25 --- /dev/null +++ b/pkg/driver/node/envprovider/provider_test.go @@ -0,0 +1,149 @@ +package envprovider_test + +import ( + "testing" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" +) + +func TestGettingRegion(t *testing.T) { + testCases := []struct { + name string + envRegion string + envDefaultRegion string + want string + }{ + { + name: "both region envs are set", + envRegion: "us-west-1", + envDefaultRegion: "us-east-1", + want: "us-west-1", + }, + { + name: "only default region env is set", + envRegion: "", + envDefaultRegion: "us-east-1", + want: "us-east-1", + }, + { + name: "no region env is set", + envRegion: "", + envDefaultRegion: "", + want: "", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Setenv("AWS_REGION", testCase.envRegion) + t.Setenv("AWS_DEFAULT_REGION", testCase.envDefaultRegion) + assert.Equals(t, testCase.want, envprovider.Region()) + }) + } +} + +func TestProvidingEnvironmentVariables(t *testing.T) { + testCases := []struct { + name string + env map[string]string + want []string + }{ + { + name: "no env vars set", + env: map[string]string{}, + want: []string{}, + }, + { + name: "some allowed env vars set", + env: map[string]string{ + "AWS_REGION": "us-west-1", + "AWS_DEFAULT_REGION": "us-east-1", + "AWS_STS_REGIONAL_ENDPOINTS": "regional", + "AWS_MAX_ATTEMPTS": "10", + }, + want: []string{ + "AWS_REGION=us-west-1", + "AWS_DEFAULT_REGION=us-east-1", + "AWS_STS_REGIONAL_ENDPOINTS=regional", + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + for k, v := range testCase.env { + t.Setenv(k, v) + } + assert.Equals(t, testCase.want, envprovider.Provide()) + }) + } +} + +func TestFormattingEnvironmentVariable(t *testing.T) { + testCases := []struct { + name string + key string + value string + want string + }{ + { + name: "region", + key: "AWS_REGION", + value: "us-west-1", + want: "AWS_REGION=us-west-1", + }, + { + name: "role arn", + key: "AWS_ROLE_ARN", + value: "arn:aws:iam::account:role/csi-driver-role-name", + want: "AWS_ROLE_ARN=arn:aws:iam::account:role/csi-driver-role-name", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + assert.Equals(t, testCase.want, envprovider.Format(testCase.key, testCase.value)) + }) + } +} + +func TestRemovingAKeyFromListOfEnvironmentVariables(t *testing.T) { + testCases := []struct { + name string + env envprovider.Environment + key string + want envprovider.Environment + }{ + { + name: "empty environment", + env: envprovider.Environment{}, + key: "AWS_REGION", + want: envprovider.Environment{}, + }, + { + name: "remove existing key", + env: envprovider.Environment{"AWS_REGION=us-west-1", "AWS_DEFAULT_REGION=us-east-1"}, + key: "AWS_REGION", + want: envprovider.Environment{"AWS_DEFAULT_REGION=us-east-1"}, + }, + { + name: "remove existing key with equals sign", + env: envprovider.Environment{"AWS_REGION=us-west-1", "AWS_DEFAULT_REGION=us-east-1"}, + key: "AWS_REGION=", + want: envprovider.Environment{"AWS_DEFAULT_REGION=us-east-1"}, + }, + { + name: "remove non-existing key", + env: envprovider.Environment{"AWS_REGION=us-west-1", "AWS_DEFAULT_REGION=us-east-1"}, + key: "AWS_MAX_ATTEMPTS", + want: envprovider.Environment{"AWS_REGION=us-west-1", "AWS_DEFAULT_REGION=us-east-1"}, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + assert.Equals(t, testCase.want, envprovider.Remove(testCase.env, testCase.key)) + }) + } +} diff --git a/pkg/driver/node/mounter/credential_provider.go b/pkg/driver/node/mounter/credential_provider.go deleted file mode 100644 index 12cc4a82..00000000 --- a/pkg/driver/node/mounter/credential_provider.go +++ /dev/null @@ -1,310 +0,0 @@ -package mounter - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "os" - "path" - "strings" - "sync" - "time" - - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" - "github.com/google/renameio" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - k8sv1 "k8s.io/client-go/kubernetes/typed/core/v1" - "k8s.io/klog/v2" - k8sstrings "k8s.io/utils/strings" - - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" -) - -const hostPluginDirEnv = "HOST_PLUGIN_DIR" - -type AuthenticationSource = string - -const ( - // This is when users don't provide a `authenticationSource` option in their volume attributes. - // We're defaulting to `driver` in this case. - AuthenticationSourceUnspecified AuthenticationSource = "" - AuthenticationSourceDriver AuthenticationSource = "driver" - AuthenticationSourcePod AuthenticationSource = "pod" -) - -const ( - // This is to ensure only owner/group can read the file and no one else. - serviceAccountTokenPerm = 0440 -) - -const defaultHostPluginDir = "/var/lib/kubelet/plugins/s3.csi.aws.com/" - -const serviceAccountTokenAudienceSTS = "sts.amazonaws.com" - -const serviceAccountRoleAnnotation = "eks.amazonaws.com/role-arn" - -const podLevelCredentialsDocsPage = "https://github.com/awslabs/mountpoint-s3-csi-driver/blob/main/docs/CONFIGURATION.md#pod-level-credentials" -const stsConfigDocsPage = "https://github.com/awslabs/mountpoint-s3-csi-driver/blob/main/docs/CONFIGURATION.md#configuring-the-sts-region" - -var errUnknownRegion = errors.New("NodePublishVolume: Pod-level: unknown region") - -type Token struct { - Token string `json:"token"` - ExpirationTimestamp time.Time `json:"expirationTimestamp"` -} - -type CredentialProvider struct { - client k8sv1.CoreV1Interface - containerPluginDir string - regionFromIMDS func() (string, error) -} - -func NewCredentialProvider(client k8sv1.CoreV1Interface, containerPluginDir string, regionFromIMDS func() (string, error)) *CredentialProvider { - // `regionFromIMDS` is a `sync.OnceValues` and it only makes request to IMDS once, - // this call is basically here to pre-warm the cache of IMDS call. - go func() { - _, _ = regionFromIMDS() - }() - - return &CredentialProvider{client, containerPluginDir, regionFromIMDS} -} - -// CleanupToken cleans any created service token files for given volume and pod. -func (c *CredentialProvider) CleanupToken(volumeID string, podID string) error { - err := os.Remove(c.tokenPathContainer(podID, volumeID)) - if err != nil && os.IsNotExist(err) { - return nil - } - return err -} - -// Provide provides mount credentials for given volume and volume context. -// Depending on the configuration, it either returns driver-level or pod-level credentials. -func (c *CredentialProvider) Provide(ctx context.Context, volumeID string, volumeCtx map[string]string, mountpointArgs []string) (*MountCredentials, error) { - if volumeCtx == nil { - return nil, status.Error(codes.InvalidArgument, "Missing volume context") - } - - authenticationSource := volumeCtx[volumecontext.AuthenticationSource] - switch authenticationSource { - case AuthenticationSourcePod: - return c.provideFromPod(ctx, volumeID, volumeCtx, mountpointArgs) - case AuthenticationSourceUnspecified, AuthenticationSourceDriver: - return c.provideFromDriver() - default: - return nil, fmt.Errorf("unknown `authenticationSource`: %s, only `driver` (default option if not specified) and `pod` supported", authenticationSource) - } -} - -func (c *CredentialProvider) provideFromDriver() (*MountCredentials, error) { - klog.V(4).Infof("NodePublishVolume: Using driver identity") - - hostPluginDir := hostPluginDirWithDefault() - hostTokenPath := path.Join(hostPluginDir, "token") - - return &MountCredentials{ - AuthenticationSource: AuthenticationSourceDriver, - AccessKeyID: os.Getenv(keyIdEnv), - SecretAccessKey: os.Getenv(accessKeyEnv), - SessionToken: os.Getenv(sessionTokenEnv), - Region: os.Getenv(regionEnv), - DefaultRegion: os.Getenv(defaultRegionEnv), - WebTokenPath: hostTokenPath, - StsEndpoints: os.Getenv(stsEndpointsEnv), - AwsRoleArn: os.Getenv(roleArnEnv), - }, nil -} - -func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string, volumeCtx map[string]string, mountpointArgs []string) (*MountCredentials, error) { - klog.V(4).Infof("NodePublishVolume: Using pod identity") - - tokensJson := volumeCtx[volumecontext.CSIServiceAccountTokens] - if tokensJson == "" { - klog.Error("`authenticationSource` configured to `pod` but no service account tokens are received. Please make sure to enable `podInfoOnMountCompat`, see " + podLevelCredentialsDocsPage) - return nil, status.Error(codes.InvalidArgument, "Missing service account tokens") - } - - var tokens map[string]*Token - if err := json.Unmarshal([]byte(tokensJson), &tokens); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "Failed to parse service account tokens: %v", err) - } - - stsToken := tokens[serviceAccountTokenAudienceSTS] - if stsToken == nil { - klog.Errorf("`authenticationSource` configured to `pod` but no service account tokens for %s received. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage, serviceAccountTokenAudienceSTS) - return nil, status.Errorf(codes.InvalidArgument, "Missing service account token for %s", serviceAccountTokenAudienceSTS) - } - - awsRoleARN, err := c.findPodServiceAccountRole(ctx, volumeCtx) - if err != nil { - return nil, err - } - - region, err := c.stsRegion(volumeCtx, mountpointArgs) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "Failed to detect STS AWS Region, please explicitly set the AWS Region, see "+stsConfigDocsPage) - } - - defaultRegion := os.Getenv(defaultRegionEnv) - if defaultRegion == "" { - defaultRegion = region - } - - podID := volumeCtx[volumecontext.CSIPodUID] - if podID == "" { - return nil, status.Error(codes.InvalidArgument, "Missing Pod info. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage) - } - - err = c.writeToken(podID, volumeID, stsToken) - if err != nil { - return nil, status.Errorf(codes.Internal, "Failed to write service account token: %v", err) - } - - hostPluginDir := hostPluginDirWithDefault() - hostTokenPath := path.Join(hostPluginDir, c.tokenFilename(podID, volumeID)) - - podNamespace := volumeCtx[volumecontext.CSIPodNamespace] - podServiceAccount := volumeCtx[volumecontext.CSIServiceAccountName] - cacheKey := podNamespace + "/" + podServiceAccount - - return &MountCredentials{ - AuthenticationSource: AuthenticationSourcePod, - - Region: region, - DefaultRegion: defaultRegion, - StsEndpoints: os.Getenv(stsEndpointsEnv), - WebTokenPath: hostTokenPath, - AwsRoleArn: awsRoleARN, - - // Ensure to disable env credential provider - AccessKeyID: "", - SecretAccessKey: "", - - // Ensure to disable profile provider - ConfigFilePath: path.Join(hostPluginDir, "disable-config"), - SharedCredentialsFilePath: path.Join(hostPluginDir, "disable-credentials"), - - // Ensure to disable IMDS provider - DisableIMDSProvider: true, - - MountpointCacheKey: cacheKey, - }, nil -} - -func (c *CredentialProvider) writeToken(podID string, volumeID string, token *Token) error { - return renameio.WriteFile(c.tokenPathContainer(podID, volumeID), []byte(token.Token), serviceAccountTokenPerm) -} - -func (c *CredentialProvider) tokenPathContainer(podID string, volumeID string) string { - return path.Join(c.containerPluginDir, c.tokenFilename(podID, volumeID)) -} - -func (c *CredentialProvider) tokenFilename(podID string, volumeID string) string { - var filename strings.Builder - // `podID` is a UUID, but escape it to ensure it doesn't contain `/` - filename.WriteString(k8sstrings.EscapeQualifiedName(podID)) - filename.WriteRune('-') - // `volumeID` might contain `/`, we need to escape it - filename.WriteString(k8sstrings.EscapeQualifiedName(volumeID)) - filename.WriteString(".token") - return filename.String() -} - -func (c *CredentialProvider) findPodServiceAccountRole(ctx context.Context, volumeCtx map[string]string) (string, error) { - podNamespace := volumeCtx[volumecontext.CSIPodNamespace] - podServiceAccount := volumeCtx[volumecontext.CSIServiceAccountName] - if podNamespace == "" || podServiceAccount == "" { - klog.Error("`authenticationSource` configured to `pod` but no pod info found. Please make sure to enable `podInfoOnMountCompat`, see " + podLevelCredentialsDocsPage) - return "", status.Error(codes.InvalidArgument, "Missing Pod info. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage) - } - - response, err := c.client.ServiceAccounts(podNamespace).Get(ctx, podServiceAccount, metav1.GetOptions{}) - if err != nil { - return "", status.Errorf(codes.InvalidArgument, "Failed to get pod's service account %s/%s: %v", podNamespace, podServiceAccount, err) - } - - roleArn := response.Annotations[serviceAccountRoleAnnotation] - if roleArn == "" { - klog.Error("`authenticationSource` configured to `pod` but pod's service account is not annotated with a role, see " + podLevelCredentialsDocsPage) - return "", status.Errorf(codes.InvalidArgument, "Missing role annotation on pod's service account %s/%s", podNamespace, podServiceAccount) - } - - return roleArn, nil -} - -// stsRegion tries to detect AWS region to use for STS. -// -// It looks for the following (in-order): -// 1. `stsRegion` passed via volume context -// 2. Region set for S3 bucket via mount options -// 3. `AWS_REGION` or `AWS_DEFAULT_REGION` env variables -// 4. Calling IMDS to detect region -// -// It returns an error if all of them fails. -func (c *CredentialProvider) stsRegion(volumeCtx map[string]string, mountpointArgs []string) (string, error) { - region := volumeCtx[volumecontext.STSRegion] - if region != "" { - klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from volume context", region) - return region, nil - } - - if region, ok := ExtractMountpointArgument(mountpointArgs, mountpointArgRegion); ok { - klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from S3 bucket region", region) - return region, nil - } - - region = os.Getenv(regionEnv) - if region != "" { - klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from `AWS_REGION` env variable", region) - return region, nil - } - - region = os.Getenv(defaultRegionEnv) - if region != "" { - klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from `AWS_DEFAULT_REGION` env variable", region) - return region, nil - } - - // We're ignoring the error here, makes a call to IMDS only once and logs the error in case of error - region, _ = c.regionFromIMDS() - if region != "" { - klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from IMDS", region) - return region, nil - } - - return "", errUnknownRegion -} - -func hostPluginDirWithDefault() string { - hostPluginDir := os.Getenv(hostPluginDirEnv) - if hostPluginDir == "" { - hostPluginDir = defaultHostPluginDir - } - return hostPluginDir -} - -// RegionFromIMDSOnce tries to detect AWS region by making a request to IMDS. -// It only makes request to IMDS once and caches the value. -var RegionFromIMDSOnce = sync.OnceValues(func() (string, error) { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - cfg, err := config.LoadDefaultConfig(ctx) - if err != nil { - klog.V(5).Infof("NodePublishVolume: Pod-level: Failed to create config for IMDS client: %v", err) - return "", fmt.Errorf("could not create config for imds client: %w", err) - } - - client := imds.NewFromConfig(cfg) - output, err := client.GetRegion(ctx, &imds.GetRegionInput{}) - if err != nil { - klog.V(5).Infof("NodePublishVolume: Pod-level: Failed to get region from IMDS: %v", err) - return "", fmt.Errorf("failed to get region from imds: %w", err) - } - - return output.Region, nil -}) diff --git a/pkg/driver/node/mounter/credential_provider_test.go b/pkg/driver/node/mounter/credential_provider_test.go deleted file mode 100644 index 26f5a7c9..00000000 --- a/pkg/driver/node/mounter/credential_provider_test.go +++ /dev/null @@ -1,598 +0,0 @@ -package mounter_test - -import ( - "context" - "encoding/json" - "errors" - "os" - "path" - "testing" - "time" - - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" - - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/kubernetes/fake" -) - -func TestProvidingDriverLevelCredentials(t *testing.T) { - t.Setenv("AWS_ACCESS_KEY_ID", "test-access-key") - t.Setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") - t.Setenv("AWS_SESSION_TOKEN", "test-session-token") - t.Setenv("AWS_REGION", "eu-west-1") - t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - t.Setenv("HOST_PLUGIN_DIR", "/test/csi/plugin/dir") - t.Setenv("AWS_STS_REGIONAL_ENDPOINTS", "regional") - t.Setenv("AWS_ROLE_ARN", "arn:aws:iam::123456789012:role/Test") - - for _, test := range []struct { - volumeID string - volumeContext map[string]string - }{ - { - volumeID: "test-vol-id", - volumeContext: map[string]string{"authenticationSource": "driver"}, - }, - { - volumeID: "test-vol-id", - // It should default to `driver` if `authenticationSource` is not explicitly set - volumeContext: map[string]string{}, - }, - } { - - provider := mounter.NewCredentialProvider(nil, "", mounter.RegionFromIMDSOnce) - credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, nil) - assertEquals(t, nil, err) - - assertEquals(t, credentials.AccessKeyID, "test-access-key") - assertEquals(t, credentials.SecretAccessKey, "test-secret-key") - assertEquals(t, credentials.SessionToken, "test-session-token") - assertEquals(t, credentials.Region, "eu-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-north-1") - assertEquals(t, credentials.WebTokenPath, "/test/csi/plugin/dir/token") - assertEquals(t, credentials.StsEndpoints, "regional") - assertEquals(t, credentials.AwsRoleArn, "arn:aws:iam::123456789012:role/Test") - } -} - -func TestProvidingDriverLevelCredentialsWithEmptyEnv(t *testing.T) { - provider := mounter.NewCredentialProvider(nil, "", mounter.RegionFromIMDSOnce) - credentials, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{"authenticationSource": "driver"}, nil) - assertEquals(t, nil, err) - - assertEquals(t, credentials.AccessKeyID, "") - assertEquals(t, credentials.SecretAccessKey, "") - assertEquals(t, credentials.SessionToken, "") - assertEquals(t, credentials.Region, "") - assertEquals(t, credentials.DefaultRegion, "") - assertEquals(t, credentials.WebTokenPath, "/var/lib/kubelet/plugins/s3.csi.aws.com/token") - assertEquals(t, credentials.StsEndpoints, "") - assertEquals(t, credentials.AwsRoleArn, "") -} - -func TestProvidingPodLevelCredentials(t *testing.T) { - pluginDir := t.TempDir() - clientset := fake.NewSimpleClientset(serviceAccount("test-sa", "test-ns", map[string]string{ - "eks.amazonaws.com/role-arn": "arn:aws:iam::123456789012:role/Test", - })) - t.Setenv("AWS_REGION", "eu-west-1") - t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - t.Setenv("HOST_PLUGIN_DIR", "/test/csi/plugin/dir") - t.Setenv("AWS_STS_REGIONAL_ENDPOINTS", "regional") - - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, mounter.RegionFromIMDSOnce) - - credentials, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token", - }, - }), - }, nil) - assertEquals(t, nil, err) - - // Should disable env variable provider - assertEquals(t, credentials.AccessKeyID, "") - assertEquals(t, credentials.SecretAccessKey, "") - assertEquals(t, credentials.SessionToken, "") - - // Should disable profile provider - assertEquals(t, credentials.ConfigFilePath, "/test/csi/plugin/dir/disable-config") - assertEquals(t, credentials.SharedCredentialsFilePath, "/test/csi/plugin/dir/disable-credentials") - - // Should disable IMDS provider - assertEquals(t, credentials.DisableIMDSProvider, true) - - // Should populate env variables for STS Web Identity provider - assertEquals(t, credentials.WebTokenPath, "/test/csi/plugin/dir/test-pod-test-vol-id.token") - assertEquals(t, credentials.AwsRoleArn, "arn:aws:iam::123456789012:role/Test") - - assertEquals(t, credentials.Region, "eu-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-north-1") - assertEquals(t, credentials.StsEndpoints, "regional") - - assertEquals(t, credentials.MountpointCacheKey, "test-ns/test-sa") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) -} - -func TestProvidingPodLevelCredentialsWithMissingInformation(t *testing.T) { - pluginDir := t.TempDir() - clientset := fake.NewSimpleClientset( - serviceAccount("test-sa", "test-ns", map[string]string{ - "eks.amazonaws.com/role-arn": "arn:aws:iam::123456789012:role/Test", - }), - serviceAccount("test-sa-missing-role", "test-ns", map[string]string{}), - ) - - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, mounter.RegionFromIMDSOnce) - - for name, test := range map[string]struct { - volumeID string - volumeContext map[string]string - }{ - "unknown service account": { - volumeID: "test-vol-id", - volumeContext: map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-unknown-sa", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token", - }, - }), - }, - }, - "missing service account token": { - volumeID: "test-vol-id", - volumeContext: map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa", - }, - }, - "missing sts audience in service account tokens": { - volumeID: "test-vol-id", - volumeContext: map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "unknown": { - Token: "test-service-account-token", - }, - }), - }, - }, - "missing service account name": { - volumeID: "test-vol-id", - volumeContext: map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token", - }, - }), - }, - }, - "missing pod namespace": { - volumeID: "test-vol-id", - volumeContext: map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/serviceAccount.name": "test-sa", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token", - }, - }), - }, - }, - "missing pod id": { - volumeID: "test-vol-id", - volumeContext: map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token", - }, - }), - }, - }, - } { - t.Run(name, func(t *testing.T) { - credentials, err := provider.Provide(context.Background(), test.volumeID, test.volumeContext, nil) - assertEquals(t, nil, credentials) - if err == nil { - t.Error("it should fail with missing information") - } - - _, err = os.ReadFile(path.Join(pluginDir, "test-pod-test-vol-id.token")) - assertEquals(t, true, os.IsNotExist(err)) - }) - } -} - -func TestProvidingPodLevelCredentialsRegionPopulation(t *testing.T) { - clientset := fake.NewSimpleClientset(serviceAccount("test-sa", "test-ns", map[string]string{ - "eks.amazonaws.com/role-arn": "arn:aws:iam::123456789012:role/Test", - })) - - volumeID := "test-vol-id" - volumeContext := map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token", - }, - }), - } - - t.Run("no region", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "", errors.New("unknown region") - }) - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) - assertEquals(t, nil, credentials) - if err == nil { - t.Error("it should fail if there is not any region information") - } - - _, err = os.ReadFile(path.Join(pluginDir, "test-pod-test-vol-id.token")) - assertEquals(t, true, os.IsNotExist(err)) - }) - - t.Run("region from imds", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "us-east-1") - assertEquals(t, credentials.DefaultRegion, "us-east-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("region from env", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_REGION", "eu-west-1") - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "eu-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-west-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("default region from env", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_DEFAULT_REGION", "eu-west-1") - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "eu-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-west-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("default and regular region from env", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_REGION", "eu-west-1") - t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, nil) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "eu-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-north-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("region from mountpoint options", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_REGION", "eu-west-1") - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"}) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "us-west-1") - assertEquals(t, credentials.DefaultRegion, "us-west-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("missing region from mountpoint options", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_REGION", "eu-west-1") - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--read-only"}) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "eu-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-west-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("region from mountpoint options with default region from env", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_REGION", "eu-west-1") - t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"}) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "us-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-north-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("region from volume context", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_REGION", "eu-west-1") - - volumeContext["stsRegion"] = "ap-south-1" - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"}) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "ap-south-1") - assertEquals(t, credentials.DefaultRegion, "ap-south-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) - - t.Run("region from volume context with default region from env", func(t *testing.T) { - pluginDir := t.TempDir() - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, func() (string, error) { - return "us-east-1", nil - }) - - t.Setenv("AWS_REGION", "eu-west-1") - t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - - volumeContext["stsRegion"] = "ap-south-1" - - credentials, err := provider.Provide(context.Background(), volumeID, volumeContext, []string{"--region=us-west-1"}) - assertEquals(t, nil, err) - assertEquals(t, credentials.Region, "ap-south-1") - assertEquals(t, credentials.DefaultRegion, "eu-north-1") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) - }) -} - -func TestProvidingPodLevelCredentialsForDifferentPodsWithDifferentRoles(t *testing.T) { - pluginDir := t.TempDir() - clientset := fake.NewSimpleClientset( - serviceAccount("test-sa-1", "test-ns", map[string]string{ - "eks.amazonaws.com/role-arn": "arn:aws:iam::123456789012:role/Test1", - }), - serviceAccount("test-sa-2", "test-ns", map[string]string{ - "eks.amazonaws.com/role-arn": "arn:aws:iam::123456789012:role/Test2", - }), - ) - t.Setenv("AWS_REGION", "eu-west-1") - t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - t.Setenv("HOST_PLUGIN_DIR", "/test/csi/plugin/dir") - t.Setenv("AWS_STS_REGIONAL_ENDPOINTS", "regional") - - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, mounter.RegionFromIMDSOnce) - - credentialsPodOne, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod-1", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa-1", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token-1", - }, - }), - }, nil) - assertEquals(t, nil, err) - - credentialsPodTwo, err := provider.Provide(context.Background(), "test-vol-id", map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod-2", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa-2", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token-2", - }, - }), - }, nil) - assertEquals(t, nil, err) - - // PodOne - assertEquals(t, credentialsPodOne.AccessKeyID, "") - assertEquals(t, credentialsPodOne.SecretAccessKey, "") - assertEquals(t, credentialsPodOne.SessionToken, "") - assertEquals(t, credentialsPodOne.Region, "eu-west-1") - assertEquals(t, credentialsPodOne.DefaultRegion, "eu-north-1") - assertEquals(t, credentialsPodOne.WebTokenPath, "/test/csi/plugin/dir/test-pod-1-test-vol-id.token") - assertEquals(t, credentialsPodOne.StsEndpoints, "regional") - assertEquals(t, credentialsPodOne.AwsRoleArn, "arn:aws:iam::123456789012:role/Test1") - assertEquals(t, credentialsPodOne.MountpointCacheKey, "test-ns/test-sa-1") - - token, err := os.ReadFile(tokenFilePath(credentialsPodOne, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token-1", string(token)) - - // PodTwo - assertEquals(t, credentialsPodTwo.AccessKeyID, "") - assertEquals(t, credentialsPodTwo.SecretAccessKey, "") - assertEquals(t, credentialsPodTwo.SessionToken, "") - assertEquals(t, credentialsPodTwo.Region, "eu-west-1") - assertEquals(t, credentialsPodTwo.DefaultRegion, "eu-north-1") - assertEquals(t, credentialsPodTwo.WebTokenPath, "/test/csi/plugin/dir/test-pod-2-test-vol-id.token") - assertEquals(t, credentialsPodTwo.StsEndpoints, "regional") - assertEquals(t, credentialsPodTwo.AwsRoleArn, "arn:aws:iam::123456789012:role/Test2") - assertEquals(t, credentialsPodTwo.MountpointCacheKey, "test-ns/test-sa-2") - - token, err = os.ReadFile(tokenFilePath(credentialsPodTwo, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token-2", string(token)) -} - -func TestProvidingPodLevelCredentialsWithSlashInVolumeID(t *testing.T) { - pluginDir := t.TempDir() - clientset := fake.NewSimpleClientset(serviceAccount("test-sa", "test-ns", map[string]string{ - "eks.amazonaws.com/role-arn": "arn:aws:iam::123456789012:role/Test", - })) - t.Setenv("AWS_REGION", "eu-west-1") - t.Setenv("AWS_DEFAULT_REGION", "eu-north-1") - t.Setenv("HOST_PLUGIN_DIR", "/test/csi/plugin/dir") - t.Setenv("AWS_STS_REGIONAL_ENDPOINTS", "regional") - - provider := mounter.NewCredentialProvider(clientset.CoreV1(), pluginDir, mounter.RegionFromIMDSOnce) - - credentials, err := provider.Provide(context.Background(), "test-vol-id/1", map[string]string{ - "authenticationSource": "pod", - "csi.storage.k8s.io/pod.uid": "test-pod", - "csi.storage.k8s.io/pod.namespace": "test-ns", - "csi.storage.k8s.io/serviceAccount.name": "test-sa", - "csi.storage.k8s.io/serviceAccount.tokens": serviceAccountTokens(t, tokens{ - "sts.amazonaws.com": { - Token: "test-service-account-token", - }, - }), - }, nil) - assertEquals(t, nil, err) - - assertEquals(t, credentials.AccessKeyID, "") - assertEquals(t, credentials.SecretAccessKey, "") - assertEquals(t, credentials.SessionToken, "") - assertEquals(t, credentials.Region, "eu-west-1") - assertEquals(t, credentials.DefaultRegion, "eu-north-1") - assertEquals(t, credentials.WebTokenPath, "/test/csi/plugin/dir/test-pod-test-vol-id~1.token") - assertEquals(t, credentials.StsEndpoints, "regional") - assertEquals(t, credentials.AwsRoleArn, "arn:aws:iam::123456789012:role/Test") - - token, err := os.ReadFile(tokenFilePath(credentials, pluginDir)) - assertEquals(t, nil, err) - assertEquals(t, "test-service-account-token", string(token)) -} - -func TestCleaningUpTokenFileForAVolume(t *testing.T) { - t.Run("existing token", func(t *testing.T) { - pluginDir := t.TempDir() - volumeID := "test-vol-id" - podID := "test-pod-id" - tokenPath := path.Join(pluginDir, podID+"-"+volumeID+".token") - err := os.WriteFile(tokenPath, []byte("test-service-account-token"), 0400) - assertEquals(t, nil, err) - - provider := mounter.NewCredentialProvider(nil, pluginDir, mounter.RegionFromIMDSOnce) - err = provider.CleanupToken(volumeID, podID) - assertEquals(t, nil, err) - - _, err = os.ReadFile(tokenPath) - assertEquals(t, true, os.IsNotExist(err)) - }) - - t.Run("non-existing token", func(t *testing.T) { - provider := mounter.NewCredentialProvider(nil, t.TempDir(), mounter.RegionFromIMDSOnce) - - err := provider.CleanupToken("non-existing-vol-id", "non-existing-pod-id") - assertEquals(t, nil, err) - }) -} - -type tokens = map[string]struct { - Token string `json:"token"` - ExpirationTimestamp time.Time -} - -func serviceAccountTokens(t *testing.T, tokens tokens) string { - buf, err := json.Marshal(&tokens) - assertEquals(t, nil, err) - return string(buf) -} - -func serviceAccount(name, namespace string, annotations map[string]string) *v1.ServiceAccount { - return &v1.ServiceAccount{ObjectMeta: metav1.ObjectMeta{ - Name: name, - Namespace: namespace, - Annotations: annotations, - }} -} - -func tokenFilePath(credentials *mounter.MountCredentials, pluginDir string) string { - return path.Join(pluginDir, path.Base(credentials.WebTokenPath)) -} - -func assertEquals[T comparable](t *testing.T, expected T, got T) { - if expected != got { - t.Errorf("Expected %#v, Got %#v", expected, got) - } -} diff --git a/pkg/driver/node/mounter/fake_mounter.go b/pkg/driver/node/mounter/fake_mounter.go index a33a3ec0..7af82a6b 100644 --- a/pkg/driver/node/mounter/fake_mounter.go +++ b/pkg/driver/node/mounter/fake_mounter.go @@ -1,9 +1,14 @@ package mounter +import ( + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" +) + type FakeMounter struct{} -func (m *FakeMounter) Mount(bucketName string, target string, - credentials *MountCredentials, options []string) error { +func (m *FakeMounter) Mount(bucketName string, target string, credentials credentialprovider.Credentials, env envprovider.Environment, args mountpoint.Args) error { return nil } diff --git a/pkg/driver/node/mounter/mocks/mock_mount.go b/pkg/driver/node/mounter/mocks/mock_mount.go index 6d6053cc..1dcc8b66 100644 --- a/pkg/driver/node/mounter/mocks/mock_mount.go +++ b/pkg/driver/node/mounter/mocks/mock_mount.go @@ -8,7 +8,9 @@ import ( context "context" reflect "reflect" - mounter "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" + credentialprovider "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + envprovider "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" + mountpoint "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" system "github.com/awslabs/aws-s3-csi-driver/pkg/system" gomock "github.com/golang/mock/gomock" ) @@ -105,17 +107,17 @@ func (mr *MockMounterMockRecorder) IsMountPoint(target interface{}) *gomock.Call } // Mount mocks base method. -func (m *MockMounter) Mount(bucketName, target string, credentials *mounter.MountCredentials, options []string) error { +func (m *MockMounter) Mount(bucketName, target string, credentials credentialprovider.Credentials, env envprovider.Environment, args mountpoint.Args) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Mount", bucketName, target, credentials, options) + ret := m.ctrl.Call(m, "Mount", bucketName, target, credentials, env, args) ret0, _ := ret[0].(error) return ret0 } // Mount indicates an expected call of Mount. -func (mr *MockMounterMockRecorder) Mount(bucketName, target, credentials, options interface{}) *gomock.Call { +func (mr *MockMounterMockRecorder) Mount(bucketName, target, credentials, env, args interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), bucketName, target, credentials, options) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), bucketName, target, credentials, env, args) } // Unmount mocks base method. diff --git a/pkg/driver/node/mounter/mount_credentials.go b/pkg/driver/node/mounter/mount_credentials.go deleted file mode 100644 index 92cdf010..00000000 --- a/pkg/driver/node/mounter/mount_credentials.go +++ /dev/null @@ -1,101 +0,0 @@ -package mounter - -import "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/awsprofile" - -const ( - awsProfileEnv = "AWS_PROFILE" - awsConfigFileEnv = "AWS_CONFIG_FILE" - awsSharedCredentialsFileEnv = "AWS_SHARED_CREDENTIALS_FILE" - keyIdEnv = "AWS_ACCESS_KEY_ID" - accessKeyEnv = "AWS_SECRET_ACCESS_KEY" - sessionTokenEnv = "AWS_SESSION_TOKEN" - disableIMDSProviderEnv = "AWS_EC2_METADATA_DISABLED" - regionEnv = "AWS_REGION" - defaultRegionEnv = "AWS_DEFAULT_REGION" - stsEndpointsEnv = "AWS_STS_REGIONAL_ENDPOINTS" - roleArnEnv = "AWS_ROLE_ARN" - webIdentityTokenEnv = "AWS_WEB_IDENTITY_TOKEN_FILE" - MountS3PathEnv = "MOUNT_S3_PATH" - awsMaxAttemptsEnv = "AWS_MAX_ATTEMPTS" - MountpointCacheKey = "UNSTABLE_MOUNTPOINT_CACHE_KEY" - defaultMountS3Path = "/usr/bin/mount-s3" - userAgentPrefix = "--user-agent-prefix" - awsMaxAttemptsOption = "--aws-max-attempts" -) - -type MountCredentials struct { - // Identifies how these credentials are obtained. - AuthenticationSource AuthenticationSource - - // -- Env variable provider - AccessKeyID string - SecretAccessKey string - SessionToken string - - // -- Profile provider - ConfigFilePath string - SharedCredentialsFilePath string - - // -- STS provider - WebTokenPath string - AwsRoleArn string - - // -- IMDS provider - DisableIMDSProvider bool - - // -- Generic - Region string - DefaultRegion string - StsEndpoints string - - // -- TODO - Move somewhere better - MountpointCacheKey string -} - -// Get environment variables to pass to mount-s3 for authentication. -func (mc *MountCredentials) Env(awsProfile awsprofile.AWSProfile) []string { - env := []string{} - - // For profile provider from long-term credentials - if awsProfile.Name != "" { - env = append(env, awsProfileEnv+"="+awsProfile.Name) - env = append(env, awsConfigFileEnv+"="+awsProfile.ConfigPath) - env = append(env, awsSharedCredentialsFileEnv+"="+awsProfile.CredentialsPath) - } else { - // For profile provider - if mc.ConfigFilePath != "" { - env = append(env, awsConfigFileEnv+"="+mc.ConfigFilePath) - } - if mc.SharedCredentialsFilePath != "" { - env = append(env, awsSharedCredentialsFileEnv+"="+mc.SharedCredentialsFilePath) - } - } - - // For STS Web Identity provider - if mc.WebTokenPath != "" { - env = append(env, webIdentityTokenEnv+"="+mc.WebTokenPath) - env = append(env, roleArnEnv+"="+mc.AwsRoleArn) - } - - // For disabling IMDS provider - if mc.DisableIMDSProvider { - env = append(env, disableIMDSProviderEnv+"=true") - } - - // Generic variables - if mc.Region != "" { - env = append(env, regionEnv+"="+mc.Region) - } - if mc.DefaultRegion != "" { - env = append(env, defaultRegionEnv+"="+mc.DefaultRegion) - } - if mc.StsEndpoints != "" { - env = append(env, stsEndpointsEnv+"="+mc.StsEndpoints) - } - - if mc.MountpointCacheKey != "" { - env = append(env, MountpointCacheKey+"="+mc.MountpointCacheKey) - } - - return env -} diff --git a/pkg/driver/node/mounter/mounter.go b/pkg/driver/node/mounter/mounter.go index ece7b7c2..71bf346f 100644 --- a/pkg/driver/node/mounter/mounter.go +++ b/pkg/driver/node/mounter/mounter.go @@ -5,6 +5,9 @@ import ( "context" "os" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" "github.com/awslabs/aws-s3-csi-driver/pkg/system" ) @@ -15,11 +18,13 @@ type ServiceRunner interface { // Mounter is an interface for mount operations type Mounter interface { - Mount(bucketName string, target string, credentials *MountCredentials, options []string) error + Mount(bucketName string, target string, credentials credentialprovider.Credentials, env envprovider.Environment, args mountpoint.Args) error Unmount(target string) error IsMountPoint(target string) (bool, error) } +const MountS3PathEnv = "MOUNT_S3_PATH" + func MountS3Path() string { mountS3Path := os.Getenv(MountS3PathEnv) if mountS3Path == "" { diff --git a/pkg/driver/node/mounter/systemd_mounter.go b/pkg/driver/node/mounter/systemd_mounter.go index e6c0db66..57fd004c 100644 --- a/pkg/driver/node/mounter/systemd_mounter.go +++ b/pkg/driver/node/mounter/systemd_mounter.go @@ -2,44 +2,48 @@ package mounter import ( "context" + "errors" "fmt" + "io/fs" "os" "path/filepath" - "strings" "time" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/awsprofile" - "github.com/awslabs/aws-s3-csi-driver/pkg/system" "github.com/google/uuid" "k8s.io/klog/v2" "k8s.io/mount-utils" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" + "github.com/awslabs/aws-s3-csi-driver/pkg/system" ) +var defaultMountS3Path = "/usr/bin/mount-s3" + // https://github.com/awslabs/mountpoint-s3/blob/9ed8b6243f4511e2013b2f4303a9197c3ddd4071/mountpoint-s3/src/cli.rs#L421 const mountpointDeviceName = "mountpoint-s3" type SystemdMounter struct { - Ctx context.Context - Runner ServiceRunner - Mounter mount.Interface - MpVersion string - MountS3Path string - kubernetesVersion string + Ctx context.Context + Runner ServiceRunner + Mounter mount.Interface + MpVersion string + MountS3Path string } -func NewSystemdMounter(mpVersion string, kubernetesVersion string) (*SystemdMounter, error) { +func NewSystemdMounter(mpVersion string) (*SystemdMounter, error) { ctx := context.Background() runner, err := system.StartOsSystemdSupervisor() if err != nil { return nil, fmt.Errorf("failed to start systemd supervisor: %w", err) } return &SystemdMounter{ - Ctx: ctx, - Runner: runner, - Mounter: mount.New(""), - MpVersion: mpVersion, - MountS3Path: MountS3Path(), - kubernetesVersion: kubernetesVersion, + Ctx: ctx, + Runner: runner, + Mounter: mount.New(""), + MpVersion: mpVersion, + MountS3Path: MountS3Path(), }, nil } @@ -79,7 +83,7 @@ func (m *SystemdMounter) IsMountPoint(target string) (bool, error) { // // This method will create the target path if it does not exist and if there is an existing corrupt // mount, it will attempt an unmount before attempting the mount. -func (m *SystemdMounter) Mount(bucketName string, target string, credentials *MountCredentials, options []string) error { +func (m *SystemdMounter) Mount(bucketName string, target string, credentials credentialprovider.Credentials, env envprovider.Environment, args mountpoint.Args) error { if bucketName == "" { return fmt.Errorf("bucket name is empty") } @@ -122,39 +126,31 @@ func (m *SystemdMounter) Mount(bucketName string, target string, credentials *Mo return fmt.Errorf("Could not check if %q is a mount point: %v, %v", target, statErr, err) } - if isMountPoint { - klog.V(4).Infof("NodePublishVolume: Target path %q is already mounted", target) - return nil + credentialsBasepath, err := m.ensureCredentialsDirExists(target) + if err != nil { + return err } - env := []string{} - var authenticationSource AuthenticationSource - if credentials != nil { - var awsProfile awsprofile.AWSProfile - if credentials.AccessKeyID != "" && credentials.SecretAccessKey != "" { - // Kubernetes creates target path in the form of "/var/lib/kubelet/pods//volumes/kubernetes.io~csi//mount". - // So the directory of the target path is unique for this mount, and we can use it to write credentials and config files. - // These files will be cleaned up in `Unmount`. - basepath := filepath.Dir(target) - awsProfile, err = awsprofile.CreateAWSProfile(basepath, credentials.AccessKeyID, credentials.SecretAccessKey, credentials.SessionToken) - if err != nil { - klog.V(4).Infof("Mount: Failed to create AWS Profile in %s: %v", basepath, err) - return fmt.Errorf("Mount: Failed to create AWS Profile in %s: %v", basepath, err) - } - } + // Note that this part happens before `isMountPoint` check, as we want to update credentials even though + // there is an existing mount point at `target`. + credentialsEnv, err := credentials.Dump(credentialsBasepath, credentialsBasepath) + if err != nil { + klog.V(4).Infof("NodePublishVolume: Failed to dump credentials for %s: %v", target, err) + return err + } - authenticationSource = credentials.AuthenticationSource + env = append(env, credentialsEnv...) - env = credentials.Env(awsProfile) + if isMountPoint { + klog.V(4).Infof("NodePublishVolume: Target path %q is already mounted", target) + return nil } - options, env = moveOptionToEnvironmentVariables(awsMaxAttemptsOption, awsMaxAttemptsEnv, options, env) - options = addUserAgentToOptions(options, UserAgent(authenticationSource, m.kubernetesVersion)) output, err := m.Runner.StartService(timeoutCtx, &system.ExecConfig{ Name: "mount-s3-" + m.MpVersion + "-" + uuid.New().String() + ".service", Description: "Mountpoint for Amazon S3 CSI driver FUSE daemon", ExecPath: m.MountS3Path, - Args: append(options, bucketName, target), + Args: append(args.SortedList(), bucketName, target), Env: env, }) @@ -168,48 +164,10 @@ func (m *SystemdMounter) Mount(bucketName string, target string, credentials *Mo return nil } -// Moves a parameter optionName from the options list to MP's environment variable list. We need this as options are -// passed to the driver in a single field, but MP sometimes only supports config from environment variables. -// Returns an updated options and environment. -func moveOptionToEnvironmentVariables(optionName string, envName string, options []string, env []string) ([]string, []string) { - optionIdx := -1 - for i, o := range options { - if strings.HasPrefix(o, optionName) { - optionIdx = i - break - } - } - if optionIdx != -1 { - // We can do replace here as we've just verified it has the right prefix - env = append(env, strings.Replace(options[optionIdx], optionName, envName, 1)) - options = append(options[:optionIdx], options[optionIdx+1:]...) - } - return options, env -} - -// method to add the user agent prefix to the Mountpoint headers -// https://github.com/awslabs/mountpoint-s3/pull/548 -func addUserAgentToOptions(options []string, userAgent string) []string { - // first remove it from the options in case it's in there - for i := len(options) - 1; i >= 0; i-- { - if strings.Contains(options[i], userAgentPrefix) { - options = append(options[:i], options[i+1:]...) - } - } - // add the hard coded S3 CSI driver user agent string - return append(options, userAgentPrefix+"="+userAgent) -} - func (m *SystemdMounter) Unmount(target string) error { timeoutCtx, cancel := context.WithTimeout(m.Ctx, 30*time.Second) defer cancel() - basepath := filepath.Dir(target) - err := awsprofile.CleanupAWSProfile(basepath) - if err != nil { - klog.V(4).Infof("Unmount: Failed to clean up AWS Profile in %s: %v", basepath, err) - } - output, err := m.Runner.RunOneshot(timeoutCtx, &system.ExecConfig{ Name: "mount-s3-umount-" + uuid.New().String() + ".service", Description: "Mountpoint for Amazon S3 CSI driver unmount", @@ -222,23 +180,35 @@ func (m *SystemdMounter) Unmount(target string) error { if output != "" { klog.V(5).Infof("umount output: %s", output) } + + credentialsBasepath := m.credentialsDir(target) + err = os.RemoveAll(credentialsBasepath) + if err != nil { + klog.V(5).Infof("NodePublishVolume: Failed to clean up credentials for %s: %v", target, err) + return nil + } + return nil } -const ( - mountpointArgRegion = "region" - mountpointArgCache = "cache" -) - -// ExtractMountpointArgument extracts value of a given argument from `mountpointArgs`. -func ExtractMountpointArgument(mountpointArgs []string, argument string) (string, bool) { - // `mountpointArgs` normalized to `--arg=val` in `S3NodeServer.NodePublishVolume`. - prefix := fmt.Sprintf("--%s=", argument) - for _, arg := range mountpointArgs { - if strings.HasPrefix(arg, prefix) { - val := strings.SplitN(arg, "=", 2)[1] - return val, true - } +// ensureCredentialsDirExists ensures credentials dir for `target` is exists. +// It returns credentials dir and any error. +func (m *SystemdMounter) ensureCredentialsDirExists(target string) (string, error) { + credentialsBasepath := m.credentialsDir(target) + err := os.Mkdir(credentialsBasepath, credentialprovider.CredentialDirPerm) + if err != nil && !errors.Is(err, fs.ErrExist) { + klog.V(4).Infof("NodePublishVolume: Failed to create credentials directory for %s: %v", target, err) + return "", err } - return "", false + + return credentialsBasepath, nil +} + +// credentialsDir returns a directory to write credentials to for given `target`. +// +// Kubernetes creates target path in the form of "/var/lib/kubelet/pods//volumes/kubernetes.io~csi//mount". +// So, the directory of the target path is unique for this mount, and we're creating a new folder in this path for credentials. +// The credentials folder and all its contents will be cleaned up in `Unmount`. +func (m *SystemdMounter) credentialsDir(target string) string { + return filepath.Join(filepath.Dir(target), "credentials") } diff --git a/pkg/driver/node/mounter/systemd_mounter_test.go b/pkg/driver/node/mounter/systemd_mounter_test.go index 86c89689..06c45e28 100644 --- a/pkg/driver/node/mounter/systemd_mounter_test.go +++ b/pkg/driver/node/mounter/systemd_mounter_test.go @@ -5,17 +5,16 @@ import ( "errors" "os" "path/filepath" - "reflect" - "slices" - "strings" "testing" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/awsprofile" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" - mock_driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter/mocks" - "github.com/awslabs/aws-s3-csi-driver/pkg/system" "github.com/golang/mock/gomock" "k8s.io/mount-utils" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" + mock_driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter/mocks" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" ) type mounterTestEnv struct { @@ -49,86 +48,29 @@ func initMounterTestEnv(t *testing.T) *mounterTestEnv { func TestS3MounterMount(t *testing.T) { testBucketName := "test-bucket" testTargetPath := filepath.Join(t.TempDir(), "mount") - testCredentials := &mounter.MountCredentials{ - AccessKeyID: "test-access-key", - SecretAccessKey: "test-secret-key", - Region: "test-region", - DefaultRegion: "test-region", - WebTokenPath: "test-web-token-path", - StsEndpoints: "test-sts-endpoint", - AwsRoleArn: "test-aws-role", - } testCases := []struct { name string bucketName string targetPath string - credentials *mounter.MountCredentials - options []string + args mountpoint.Args expectedErr bool before func(*testing.T, *mounterTestEnv) }{ { - name: "success: mounts with empty options", - bucketName: testBucketName, - targetPath: testTargetPath, - credentials: testCredentials, - options: []string{}, - before: func(t *testing.T, env *mounterTestEnv) { - env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).Return("success", nil) - }, - }, - { - name: "success: mounts with nil credentials", - bucketName: testBucketName, - targetPath: testTargetPath, - credentials: nil, - options: []string{}, + name: "success: mounts with empty options", + bucketName: testBucketName, + targetPath: testTargetPath, + args: mountpoint.ParseArgs(nil), before: func(t *testing.T, env *mounterTestEnv) { env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).Return("success", nil) }, }, - { - name: "success: replaces user agent prefix", - bucketName: testBucketName, - targetPath: testTargetPath, - credentials: nil, - options: []string{"--user-agent-prefix=mycustomuseragent"}, - before: func(t *testing.T, env *mounterTestEnv) { - env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, config *system.ExecConfig) (string, error) { - for _, a := range config.Args { - if strings.Contains(a, "mycustomuseragent") { - t.Fatal("Bad user agent") - } - } - return "success", nil - }) - }, - }, - { - name: "success: aws max attempts", - bucketName: testBucketName, - targetPath: testTargetPath, - credentials: nil, - options: []string{"--aws-max-attempts=10"}, - before: func(t *testing.T, env *mounterTestEnv) { - env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, config *system.ExecConfig) (string, error) { - for _, e := range config.Env { - if e == "AWS_MAX_ATTEMPTS=10" { - return "success", nil - } - } - t.Fatal("Bad env") - return "", nil - }) - }, - }, { name: "failure: fails on mount failure", bucketName: testBucketName, targetPath: testTargetPath, - credentials: nil, - options: []string{}, + args: mountpoint.ParseArgs(nil), expectedErr: true, before: func(t *testing.T, env *mounterTestEnv) { env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).Return("fail", errors.New("test failure")) @@ -137,26 +79,29 @@ func TestS3MounterMount(t *testing.T) { { name: "failure: won't mount empty bucket name", targetPath: testTargetPath, - credentials: testCredentials, - options: []string{}, + args: mountpoint.ParseArgs(nil), expectedErr: true, }, { name: "failure: won't mount empty target", bucketName: testBucketName, - credentials: testCredentials, - options: []string{}, + args: mountpoint.ParseArgs(nil), expectedErr: true, }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { + provider := credentialprovider.New(nil) + credentials, err := provider.Provide(context.Background(), map[string]string{}) + assert.NoError(t, err) + env := initMounterTestEnv(t) if testCase.before != nil { testCase.before(t, env) } - err := env.mounter.Mount(testCase.bucketName, testCase.targetPath, - testCase.credentials, testCase.options) + + err = env.mounter.Mount(testCase.bucketName, testCase.targetPath, + credentials, nil, testCase.args) env.mockCtl.Finish() if err != nil && !testCase.expectedErr { t.Fatal(err) @@ -165,154 +110,6 @@ func TestS3MounterMount(t *testing.T) { } } -func TestProvidingEnvVariablesForMountpointProcess(t *testing.T) { - tests := map[string]struct { - profile awsprofile.AWSProfile - credentials *mounter.MountCredentials - expected []string - }{ - "Profile Provider for long-term credentials": { - profile: awsprofile.AWSProfile{ - Name: "profile", - ConfigPath: "~/.aws/s3-csi-config", - CredentialsPath: "~/.aws/s3-csi-credentials", - }, - credentials: &mounter.MountCredentials{}, - expected: []string{ - "AWS_PROFILE=profile", - "AWS_CONFIG_FILE=~/.aws/s3-csi-config", - "AWS_SHARED_CREDENTIALS_FILE=~/.aws/s3-csi-credentials", - }, - }, - "Profile Provider": { - credentials: &mounter.MountCredentials{ - ConfigFilePath: "~/.aws/config", - SharedCredentialsFilePath: "~/.aws/credentials", - }, - expected: []string{ - "AWS_CONFIG_FILE=~/.aws/config", - "AWS_SHARED_CREDENTIALS_FILE=~/.aws/credentials", - }, - }, - "Disabling IMDS Provider": { - credentials: &mounter.MountCredentials{ - DisableIMDSProvider: true, - }, - expected: []string{ - "AWS_EC2_METADATA_DISABLED=true", - }, - }, - "STS Web Identity Provider": { - credentials: &mounter.MountCredentials{ - WebTokenPath: "/path/to/web/token", - AwsRoleArn: "arn:aws:iam::123456789012:role/Role", - }, - expected: []string{ - "AWS_WEB_IDENTITY_TOKEN_FILE=/path/to/web/token", - "AWS_ROLE_ARN=arn:aws:iam::123456789012:role/Role", - }, - }, - "Region and Default Region": { - credentials: &mounter.MountCredentials{ - Region: "us-west-2", - DefaultRegion: "us-east-1", - }, - expected: []string{ - "AWS_REGION=us-west-2", - "AWS_DEFAULT_REGION=us-east-1", - }, - }, - "STS Endpoints": { - credentials: &mounter.MountCredentials{ - StsEndpoints: "regional", - }, - expected: []string{ - "AWS_STS_REGIONAL_ENDPOINTS=regional", - }, - }, - "Mountpoint Cache Key": { - credentials: &mounter.MountCredentials{ - MountpointCacheKey: "test_cache_key", - }, - expected: []string{ - "UNSTABLE_MOUNTPOINT_CACHE_KEY=test_cache_key", - }, - }, - "All Combined": { - credentials: &mounter.MountCredentials{ - WebTokenPath: "/path/to/web/token", - AwsRoleArn: "arn:aws:iam::123456789012:role/Role", - Region: "us-west-2", - DefaultRegion: "us-east-1", - StsEndpoints: "legacy", - ConfigFilePath: "~/.aws/config", - SharedCredentialsFilePath: "~/.aws/credentials", - DisableIMDSProvider: true, - MountpointCacheKey: "test/cache/key", - }, - expected: []string{ - "AWS_WEB_IDENTITY_TOKEN_FILE=/path/to/web/token", - "AWS_ROLE_ARN=arn:aws:iam::123456789012:role/Role", - "AWS_REGION=us-west-2", - "AWS_DEFAULT_REGION=us-east-1", - "AWS_STS_REGIONAL_ENDPOINTS=legacy", - "AWS_EC2_METADATA_DISABLED=true", - "AWS_CONFIG_FILE=~/.aws/config", - "AWS_SHARED_CREDENTIALS_FILE=~/.aws/credentials", - "UNSTABLE_MOUNTPOINT_CACHE_KEY=test/cache/key", - }, - }, - } - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - actual := test.credentials.Env(test.profile) - - slices.Sort(test.expected) - slices.Sort(actual) - - if !reflect.DeepEqual(actual, test.expected) { - t.Errorf("Expected %v, but got %v", test.expected, actual) - } - }) - } -} - -func TestExtractMountpointArgument(t *testing.T) { - for name, test := range map[string]struct { - input []string - argument string - expectedToFound bool - expectedValue string - }{ - "Extract Existing Argument": { - input: []string{ - "--region=us-east-1", - }, - argument: "region", - expectedToFound: true, - expectedValue: "us-east-1", - }, - "Extract Non Existing Argument": { - input: []string{ - "--bucket=test", - }, - argument: "region", - expectedToFound: false, - }, - "Extract Non Existing Argument With Empty Input": { - argument: "region", - expectedToFound: false, - }, - } { - t.Run(name, func(t *testing.T) { - val, found := mounter.ExtractMountpointArgument(test.input, test.argument) - assertEquals(t, test.expectedToFound, found) - assertEquals(t, test.expectedValue, val) - }) - } -} - func TestIsMountPoint(t *testing.T) { testDir := t.TempDir() mountpointS3MountPath := filepath.Join(testDir, "/var/lib/kubelet/pods/46efe8aa-75d9-4b12-8fdd-0ce0c2cabd99/volumes/kubernetes.io~csi/s3-mp-csi-pv/mount") @@ -391,8 +188,8 @@ func TestIsMountPoint(t *testing.T) { t.Run(name, func(t *testing.T) { mounter := &mounter.SystemdMounter{Mounter: mount.NewFakeMounter(test.procMountsContent)} isMountPoint, err := mounter.IsMountPoint(test.target) - assertEquals(t, test.isMountPoint, isMountPoint) - assertEquals(t, test.expectErr, err != nil) + assert.Equals(t, test.isMountPoint, isMountPoint) + assert.Equals(t, test.expectErr, err != nil) }) } } diff --git a/pkg/driver/node/mounter/user_agent_test.go b/pkg/driver/node/mounter/user_agent_test.go index 5308c92f..cd308ae2 100644 --- a/pkg/driver/node/mounter/user_agent_test.go +++ b/pkg/driver/node/mounter/user_agent_test.go @@ -18,6 +18,8 @@ package mounter import ( "testing" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" ) func TestUserAgent(t *testing.T) { @@ -39,12 +41,12 @@ func TestUserAgent(t *testing.T) { }, "driver authentication source": { k8sVersion: "v1.30.2-eks-db838b0", - authenticationSource: AuthenticationSourceDriver, + authenticationSource: credentialprovider.AuthenticationSourceDriver, result: "s3-csi-driver/ credential-source#driver k8s/v1.30.2-eks-db838b0", }, "pod authentication source": { k8sVersion: "v1.30.2-eks-db838b0", - authenticationSource: AuthenticationSourcePod, + authenticationSource: credentialprovider.AuthenticationSourcePod, result: "s3-csi-driver/ credential-source#pod k8s/v1.30.2-eks-db838b0", }, } diff --git a/pkg/driver/node/node.go b/pkg/driver/node/node.go index b26670e9..83d2fb38 100644 --- a/pkg/driver/node/node.go +++ b/pkg/driver/node/node.go @@ -18,6 +18,7 @@ package node import ( "context" + "errors" "maps" "os" "strings" @@ -25,13 +26,15 @@ import ( "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "k8s.io/apimachinery/pkg/util/sets" "k8s.io/klog/v2" "k8s.io/mount-utils" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/targetpath" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/regionprovider" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" ) const ( @@ -55,30 +58,22 @@ var ( } ) +const stsConfigDocsPage = "https://github.com/awslabs/mountpoint-s3-csi-driver/blob/main/docs/CONFIGURATION.md#configuring-the-sts-region" + // S3NodeServer is the implementation of the csi.NodeServer interface type S3NodeServer struct { - NodeID string - Mounter mounter.Mounter - credentialProvider *mounter.CredentialProvider + nodeID string + mounter mounter.Mounter + credentialProvider *credentialprovider.Provider + regionProvider *regionprovider.Provider + kubernetesVersion string } -func NewS3NodeServer(nodeID string, mounter mounter.Mounter, credentialProvider *mounter.CredentialProvider) *S3NodeServer { - return &S3NodeServer{NodeID: nodeID, Mounter: mounter, credentialProvider: credentialProvider} +func NewS3NodeServer(nodeID string, mounter mounter.Mounter, credentialProvider *credentialprovider.Provider, regionProvider *regionprovider.Provider, kubernetesVersion string) *S3NodeServer { + return &S3NodeServer{nodeID: nodeID, mounter: mounter, credentialProvider: credentialProvider, regionProvider: regionProvider, kubernetesVersion: kubernetesVersion} } func (ns *S3NodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { - volumeCtx := req.GetVolumeContext() - if volumeCtx[volumecontext.AuthenticationSource] == mounter.AuthenticationSourcePod { - podID := volumeCtx[volumecontext.CSIPodUID] - volumeID := req.GetVolumeId() - if podID != "" && volumeID != "" { - err := ns.credentialProvider.CleanupToken(volumeID, podID) - if err != nil { - klog.V(4).Infof("NodeStageVolume: Failed to cleanup token for pod/volume %s/%s: %v", podID, volumeID, err) - } - } - } - return nil, status.Error(codes.Unimplemented, "") } @@ -119,36 +114,40 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl return nil, status.Error(codes.InvalidArgument, "Volume capability not supported") } - mountpointArgs := []string{} + args := []string{} if req.GetReadonly() || volCap.GetAccessMode().GetMode() == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY { - mountpointArgs = append(mountpointArgs, "--read-only") + args = append(args, mountpoint.ArgReadOnly) } - // get the mount(point) options (yaml list) if capMount := volCap.GetMount(); capMount != nil { mountFlags := capMount.GetMountFlags() - for i := range mountFlags { - // trim left and right spaces - // trim spaces in between from multiple spaces to just one i.e. uid 1001 would turn into uid 1001 - // if there is a space between, replace it with an = sign - mountFlags[i] = strings.Replace(strings.Join(strings.Fields(strings.Trim(mountFlags[i], " ")), " "), " ", "=", -1) - // prepend -- if it's not already there - if !strings.HasPrefix(mountFlags[i], "-") { - mountFlags[i] = "--" + mountFlags[i] - } - } - mountpointArgs = compileMountOptions(mountpointArgs, mountFlags) + args = append(args, mountFlags...) } - credentials, err := ns.credentialProvider.Provide(ctx, req.VolumeId, req.VolumeContext, mountpointArgs) + mountpointArgs := mountpoint.ParseArgs(args) + env := envprovider.Provide() + + mountpointArgs, env = ns.moveArgumentsToEnv(mountpointArgs, env) + + credentials, err := ns.credentialProvider.Provide(ctx, volumeCtx) if err != nil { klog.Errorf("NodePublishVolume: failed to provide credentials: %v", err) return nil, err } - klog.V(4).Infof("NodePublishVolume: mounting %s at %s with options %v", bucket, target, mountpointArgs) + mountpointArgs = ns.addUserAgentToArguments(mountpointArgs, credentials) + + // We need to ensure we're using region for STS if Pod-level identity is used. + if credentials.Source() == credentialprovider.AuthenticationSourcePod { + env, err = ns.overrideRegionEnvFromSTSRegion(volumeCtx, mountpointArgs, env) + if err != nil { + return nil, err + } + } + + klog.V(4).Infof("NodePublishVolume: mounting %s at %s with options %v", bucket, target, mountpointArgs.SortedList()) - if err := ns.Mounter.Mount(bucket, target, credentials, mountpointArgs); err != nil { + if err := ns.mounter.Mount(bucket, target, credentials, env, mountpointArgs); err != nil { os.Remove(target) return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", bucket, target, err) } @@ -157,38 +156,6 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl return &csi.NodePublishVolumeResponse{}, nil } -/** - * Compile mounting options into a singular set - */ -func compileMountOptions(currentOptions []string, newOptions []string) []string { - allMountOptions := sets.NewString() - - for _, currentMountOptions := range currentOptions { - if len(currentMountOptions) > 0 { - allMountOptions.Insert(currentMountOptions) - } - } - - for _, mountOption := range newOptions { - // disallow options that don't make sense in CSI - switch mountOption { - case "--foreground", "-f", "--help", "-h", "--version", "-v": - continue - } - allMountOptions.Insert(mountOption) - } - - return allMountOptions.List() -} - -func getKubeletPath() string { - kubeletPath := os.Getenv("KUBELET_PATH") - if kubeletPath == "" { - return defaultKubeletPath - } - return kubeletPath -} - func (ns *S3NodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { klog.V(4).Infof("NodeUnpublishVolume: called with args %+v", req) @@ -202,7 +169,7 @@ func (ns *S3NodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUn return nil, status.Error(codes.InvalidArgument, "Target path not provided") } - mounted, err := ns.Mounter.IsMountPoint(target) + mounted, err := ns.mounter.IsMountPoint(target) if err != nil && os.IsNotExist(err) { klog.V(4).Infof("NodeUnpublishVolume: target path %s does not exist, skipping unmount", target) return &csi.NodeUnpublishVolumeResponse{}, nil @@ -218,24 +185,12 @@ func (ns *S3NodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUn } klog.V(4).Infof("NodeUnpublishVolume: unmounting %s", target) - err = ns.Mounter.Unmount(target) + err = ns.mounter.Unmount(target) if err != nil { return nil, status.Errorf(codes.Internal, "Could not unmount %q: %v", target, err) } - targetPath, err := targetpath.Parse(target) - if err == nil { - if targetPath.VolumeID != volumeID { - klog.V(4).Infof("NodeUnpublishVolume: Volume ID from parsed target path differs from Volume ID passed: %s (parsed) != %s (passed)", targetPath.VolumeID, volumeID) - } else { - err := ns.credentialProvider.CleanupToken(targetPath.VolumeID, targetPath.PodID) - if err != nil { - klog.V(4).Infof("NodeUnpublishVolume: Failed to cleanup token for pod/volume %s/%s: %v", targetPath.PodID, volumeID, err) - } - } - } else { - klog.V(4).Infof("NodeUnpublishVolume: Failed to parse target path %s: %v", target, err) - } + klog.V(4).Infof("NodeUnpublishVolume: unmounted %s", target) return &csi.NodeUnpublishVolumeResponse{}, nil } @@ -268,7 +223,7 @@ func (ns *S3NodeServer) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoReq klog.V(4).Infof("NodeGetInfo: called with args %+v", req) return &csi.NodeGetInfoResponse{ - NodeId: ns.NodeID, + NodeId: ns.nodeID, }, nil } @@ -291,6 +246,39 @@ func (ns *S3NodeServer) isValidVolumeCapabilities(volCaps []*csi.VolumeCapabilit return foundAll } +// moveArgumentsToEnv moves `--aws-max-attempts` from arguments to environment variables if provided. +func (ns *S3NodeServer) moveArgumentsToEnv(args mountpoint.Args, env envprovider.Environment) (mountpoint.Args, envprovider.Environment) { + if maxAttempts, ok := args.Remove(mountpoint.ArgAWSMaxAttempts); ok { + env = append(env, envprovider.Format(envprovider.EnvMaxAttempts, maxAttempts)) + } + + return args, env +} + +// addUserAgentToArguments adds user-agent to Mountpoint arguments. +func (ns *S3NodeServer) addUserAgentToArguments(args mountpoint.Args, credentials credentialprovider.Credentials) mountpoint.Args { + // Remove existing user-agent if provided to ensure we always use the correct user-agent + _, _ = args.Remove(mountpoint.ArgUserAgentPrefix) + userAgent := mounter.UserAgent(credentials.Source(), ns.kubernetesVersion) + args.Insert(envprovider.Format(mountpoint.ArgUserAgentPrefix, userAgent)) + + return args +} + +// overrideRegionEnvFromSTSRegion overrides provided region with the region configured for STS. +func (ns *S3NodeServer) overrideRegionEnvFromSTSRegion(volumeContext map[string]string, args mountpoint.Args, env envprovider.Environment) (envprovider.Environment, error) { + env = envprovider.Remove(env, envprovider.EnvRegion) + region, err := ns.regionProvider.SecurityTokenService(volumeContext, args) + if err != nil { + if errors.Is(err, regionprovider.ErrUnknownRegion) { + return env, status.Errorf(codes.InvalidArgument, "Failed to detect STS AWS Region, please explicitly set the AWS Region, see "+stsConfigDocsPage) + } + return env, err + } + env = append(env, envprovider.Format(envprovider.EnvRegion, region)) + return env, nil +} + // logSafeNodePublishVolumeRequest returns a copy of given `csi.NodePublishVolumeRequest` // with sensitive fields removed. func logSafeNodePublishVolumeRequest(req *csi.NodePublishVolumeRequest) *csi.NodePublishVolumeRequest { @@ -307,3 +295,11 @@ func logSafeNodePublishVolumeRequest(req *csi.NodePublishVolumeRequest) *csi.Nod VolumeContext: safeVolumeContext, } } + +func getKubeletPath() string { + kubeletPath := os.Getenv("KUBELET_PATH") + if kubeletPath == "" { + return defaultKubeletPath + } + return kubeletPath +} diff --git a/pkg/driver/node/node_test.go b/pkg/driver/node/node_test.go index 7db0178e..e8426b63 100644 --- a/pkg/driver/node/node_test.go +++ b/pkg/driver/node/node_test.go @@ -4,19 +4,18 @@ import ( "errors" "fmt" "io/fs" - "os" - "path/filepath" "testing" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" - mock_driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter/mocks" - "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" csi "github.com/container-storage-interface/spec/lib/go/csi" "github.com/golang/mock/gomock" - "github.com/google/go-cmp/cmp/cmpopts" - "github.com/google/uuid" "golang.org/x/net/context" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" + mock_driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter/mocks" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/regionprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" ) type nodeServerTestEnv struct { @@ -29,11 +28,12 @@ func initNodeServerTestEnv(t *testing.T) *nodeServerTestEnv { mockCtl := gomock.NewController(t) defer mockCtl.Finish() mockMounter := mock_driver.NewMockMounter(mockCtl) - credentialProvider := mounter.NewCredentialProvider(nil, t.TempDir(), mounter.RegionFromIMDSOnce) server := node.NewS3NodeServer( "test-nodeID", mockMounter, - credentialProvider, + credentialprovider.New(nil), + regionprovider.New(regionprovider.RegionFromIMDSOnce), + "v1.31.0", ) return &nodeServerTestEnv{ mockCtl: mockCtl, @@ -43,6 +43,9 @@ func initNodeServerTestEnv(t *testing.T) *nodeServerTestEnv { } func TestNodePublishVolume(t *testing.T) { + userAgent := mounter.UserAgent(credentialprovider.AuthenticationSourceDriver, "v1.31.0") + userAgentArg := fmt.Sprintf("%s=%s", mountpoint.ArgUserAgentPrefix, userAgent) + var ( volumeId = "test-volume-id" bucketName = "test-bucket-name" @@ -72,7 +75,7 @@ func TestNodePublishVolume(t *testing.T) { VolumeContext: map[string]string{"bucketName": bucketName}, } - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Any()) + nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Any(), gomock.Any()) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -100,7 +103,12 @@ func TestNodePublishVolume(t *testing.T) { VolumeContext: map[string]string{"bucketName": bucketName}, } - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq([]string{"--read-only"})) + nodeTestEnv.mockMounter.EXPECT().Mount( + gomock.Eq(bucketName), + gomock.Eq(targetPath), + gomock.Any(), + gomock.Any(), + gomock.Eq(mountpoint.ParsedArgs([]string{"--read-only", userAgentArg}))) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -131,7 +139,12 @@ func TestNodePublishVolume(t *testing.T) { Readonly: true, } - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq([]string{"--bar", "--foo", "--read-only", "--test=123"})) + nodeTestEnv.mockMounter.EXPECT().Mount( + gomock.Eq(bucketName), + gomock.Eq(targetPath), + gomock.Any(), + gomock.Any(), + gomock.Eq(mountpoint.ParsedArgs([]string{"--bar", "--foo", "--read-only", "--test=123", userAgentArg}))) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -163,8 +176,12 @@ func TestNodePublishVolume(t *testing.T) { } nodeTestEnv.mockMounter.EXPECT().Mount( - gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), - gomock.Eq([]string{"--read-only", "--test=123"})).Return(nil) + gomock.Eq(bucketName), + gomock.Eq(targetPath), + gomock.Any(), + gomock.Any(), + gomock.Eq(mountpoint.ParsedArgs([]string{"--read-only", "--test=123", userAgentArg})), + ).Return(nil) _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) @@ -291,31 +308,6 @@ func TestNodeUnpublishVolume(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, tc.testFunc) } - - t.Run("Cleaning Service Account Token", func(t *testing.T) { - containerPluginDir := t.TempDir() - credentialProvider := mounter.NewCredentialProvider(nil, containerPluginDir, mounter.RegionFromIMDSOnce) - nodeServer := node.NewS3NodeServer("test-node-id", &dummyMounter{}, credentialProvider) - - podID := uuid.New().String() - volID := "test-vol-id" - - serviceAccountTokenPath := filepath.Join(containerPluginDir, fmt.Sprintf("%s-%s.token", podID, volID)) - _, err := os.Create(serviceAccountTokenPath) - assert.Equals(t, nil, err) - - targetPath := fmt.Sprintf("/var/lib/kubelet/pods/%s/volumes/kubernetes.io~csi/%s/mount", podID, volID) - - _, err = nodeServer.NodeUnpublishVolume(context.Background(), &csi.NodeUnpublishVolumeRequest{ - VolumeId: volID, - TargetPath: targetPath, - }) - assert.Equals(t, nil, err) - - _, err = os.Stat(serviceAccountTokenPath) - assert.Equals(t, cmpopts.AnyError, err) - assert.Equals(t, true, errors.Is(err, fs.ErrNotExist)) - }) } func TestNodeGetCapabilities(t *testing.T) { @@ -335,18 +327,3 @@ func TestNodeGetCapabilities(t *testing.T) { nodeTestEnv.mockCtl.Finish() } - -var _ mounter.Mounter = &dummyMounter{} - -type dummyMounter struct { -} - -func (d *dummyMounter) Mount(bucketName string, target string, credentials *mounter.MountCredentials, options []string) error { - return nil -} -func (d *dummyMounter) Unmount(target string) error { - return nil -} -func (d *dummyMounter) IsMountPoint(target string) (bool, error) { - return true, nil -} diff --git a/pkg/driver/node/regionprovider/imds.go b/pkg/driver/node/regionprovider/imds.go new file mode 100644 index 00000000..205bd347 --- /dev/null +++ b/pkg/driver/node/regionprovider/imds.go @@ -0,0 +1,34 @@ +package regionprovider + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "k8s.io/klog/v2" +) + +// RegionFromIMDSOnce tries to detect AWS region by making a request to IMDS. +// It only makes request to IMDS once and caches the value. +var RegionFromIMDSOnce = sync.OnceValues(func() (string, error) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + cfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + klog.V(5).Infof("regionprovider: Failed to create config for IMDS client: %v", err) + return "", fmt.Errorf("could not create config for imds client: %w", err) + } + + client := imds.NewFromConfig(cfg) + output, err := client.GetRegion(ctx, &imds.GetRegionInput{}) + if err != nil { + klog.V(5).Infof("regionprovider: Failed to get region from IMDS: %v", err) + return "", fmt.Errorf("failed to get region from imds: %w", err) + } + + return output.Region, nil +}) diff --git a/pkg/driver/node/regionprovider/provider.go b/pkg/driver/node/regionprovider/provider.go new file mode 100644 index 00000000..52cd55a3 --- /dev/null +++ b/pkg/driver/node/regionprovider/provider.go @@ -0,0 +1,69 @@ +// Package regionprovider provides utilities for detecting region by +// looking environment variables, mount options, or calling IMDS. +package regionprovider + +import ( + "errors" + + "k8s.io/klog/v2" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" +) + +// ErrUnknownRegion is the error returned when the region could not be detected. +var ErrUnknownRegion = errors.New("regionprovider: unknown region") + +// A Provider provides methods for detecting regions. +type Provider struct { + regionFromIMDS func() (string, error) +} + +// New creates a new [Provider] by using given [regionFromIMDS]. +func New(regionFromIMDS func() (string, error)) *Provider { + // `regionFromIMDS` is a `sync.OnceValues` and it only makes request to IMDS once, + // this call is basically here to pre-warm the cache of IMDS call. + go func() { + _, _ = regionFromIMDS() + }() + + return &Provider{regionFromIMDS: regionFromIMDS} +} + +// SecurityTokenService tries to detect AWS region to use for STS. +// +// It looks for the following (in-order): +// 1. `stsRegion` passed via volume context +// 2. Region set for S3 bucket via mount options +// 3. `AWS_REGION` or `AWS_DEFAULT_REGION` env variables +// 4. Calling IMDS to detect region +// +// It returns [ErrUnknownRegion] if all of them fails. +func (p *Provider) SecurityTokenService(volumeContext map[string]string, args mountpoint.Args) (string, error) { + region := volumeContext[volumecontext.STSRegion] + if region != "" { + klog.V(5).Infof("regionprovider: Detected STS region %s from volume context", region) + return region, nil + } + + if region, ok := args.Value(mountpoint.ArgRegion); ok { + klog.V(5).Infof("regionprovider: Detected STS region %s from S3 bucket region", region) + return region, nil + } + + region = envprovider.Region() + if region != "" { + klog.V(5).Infof("regionprovider: Detected STS region %s from env variable", region) + return region, nil + } + + // We're ignoring the error here, makes a call to IMDS only once and logs the error in case of error + region, _ = p.regionFromIMDS() + if region != "" { + klog.V(5).Infof("regionprovider: Detected STS region %s from IMDS", region) + return region, nil + } + + return "", ErrUnknownRegion +} diff --git a/pkg/driver/node/regionprovider/provider_test.go b/pkg/driver/node/regionprovider/provider_test.go new file mode 100644 index 00000000..d6556302 --- /dev/null +++ b/pkg/driver/node/regionprovider/provider_test.go @@ -0,0 +1,93 @@ +package regionprovider_test + +import ( + "errors" + "testing" + + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/regionprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext" + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" +) + +func TestGettingRegionForSTS(t *testing.T) { + testCases := []struct { + name string + volumeContext map[string]string + args mountpoint.Args + env map[string]string + regionFromIMDS func() (string, error) + want string + wantError error + }{ + { + name: "region from volume context", + volumeContext: map[string]string{volumecontext.STSRegion: "us-west-1"}, + args: mountpoint.ParseArgs(nil), + regionFromIMDS: func() (string, error) { return "", nil }, + want: "us-west-1", + wantError: nil, + }, + { + name: "region from bucket region", + volumeContext: map[string]string{}, + args: mountpoint.ParseArgs([]string{"region us-east-1"}), + regionFromIMDS: func() (string, error) { return "", nil }, + want: "us-east-1", + wantError: nil, + }, + { + name: "region from environment variable", + volumeContext: map[string]string{}, + args: mountpoint.ParseArgs(nil), + env: map[string]string{"AWS_REGION": "us-west-2"}, + regionFromIMDS: func() (string, error) { return "", nil }, + want: "us-west-2", + wantError: nil, + }, + { + name: "region from IMDS", + volumeContext: map[string]string{}, + args: mountpoint.ParseArgs(nil), + regionFromIMDS: func() (string, error) { return "us-east-2", nil }, + want: "us-east-2", + wantError: nil, + }, + { + name: "unknown region", + volumeContext: map[string]string{}, + args: mountpoint.ParseArgs(nil), + regionFromIMDS: func() (string, error) { return "", nil }, + want: "", + wantError: regionprovider.ErrUnknownRegion, + }, + { + name: "IMDS error", + volumeContext: map[string]string{}, + args: mountpoint.ParseArgs(nil), + regionFromIMDS: func() (string, error) { return "", errors.New("IMDS error") }, + want: "", + wantError: regionprovider.ErrUnknownRegion, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + if testCase.env != nil { + for key, val := range testCase.env { + t.Setenv(key, val) + } + } + + provider := regionprovider.New(testCase.regionFromIMDS) + region, err := provider.SecurityTokenService(testCase.volumeContext, testCase.args) + + assert.Equals(t, testCase.want, region) + if testCase.wantError != nil { + assert.Equals(t, err, testCase.wantError) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/mountpoint/args.go b/pkg/mountpoint/args.go new file mode 100644 index 00000000..988dc273 --- /dev/null +++ b/pkg/mountpoint/args.go @@ -0,0 +1,119 @@ +package mountpoint + +import ( + "strings" + + "k8s.io/apimachinery/pkg/util/sets" +) + +const ( + ArgForeground = "--foreground" + ArgReadOnly = "--read-only" + ArgAllowOther = "--allow-other" + ArgAllowRoot = "--allow-root" + ArgRegion = "--region" + ArgCache = "--cache" + ArgUserAgentPrefix = "--user-agent-prefix" + ArgAWSMaxAttempts = "--aws-max-attempts" +) + +// An Args represents arguments to be passed to Mountpoint during mount. +type Args struct { + args sets.Set[string] +} + +// ParseArgs parses given list of unnormalized and returns a normalized [Args]. +func ParseArgs(passedArgs []string) Args { + args := sets.New[string]() + + for _, arg := range passedArgs { + // trim left and right spaces + // trim spaces in between from multiple spaces to just one i.e. uid 1001 would turn into uid 1001 + // if there is a space between, replace it with an = sign + arg = strings.Replace(strings.Join(strings.Fields(strings.Trim(arg, " ")), " "), " ", "=", -1) + // prepend -- if it's not already there + if !strings.HasPrefix(arg, "-") { + arg = "--" + arg + } + + // disallow options that don't make sense in CSI + switch arg { + case "--foreground", "-f", "--help", "-h", "--version", "-v": + continue + } + + args.Insert(arg) + } + + return Args{args} +} + +// ParsedArgs creates [Args] from already parsed arguments by [ParseArgs]. +func ParsedArgs(parsedArgs []string) Args { + return Args{args: sets.New(parsedArgs...)} +} + +// Insert inserts given normalized argument to [Args] if not exists. +func (a *Args) Insert(arg string) { + a.args.Insert(arg) +} + +// Value extracts value of given key, it returns extracted value and whether the key was found. +func (a *Args) Value(key string) (string, bool) { + _, val, exists := a.find(key) + return val, exists +} + +// Has returns whether given key exists in [Args]. +func (a *Args) Has(key string) bool { + _, _, exists := a.find(key) + return exists +} + +// Remove removes given key, it returns the key's value and whether the key was found. +func (a *Args) Remove(key string) (string, bool) { + entry, val, exists := a.find(key) + if exists { + a.args.Delete(entry) + } + return val, exists +} + +// SortedList returns ordered list of normalized arguments. +func (a *Args) SortedList() []string { + return sets.List(a.args) +} + +// find tries to find given key from [Args], and returns whole entry, value and whether the key was found. +func (a *Args) find(key string) (string, string, bool) { + key, prefix := a.keysForSearch(key) + + for _, arg := range a.args.UnsortedList() { + if key == arg { + return key, "", true + } + + if strings.HasPrefix(arg, prefix) { + val := strings.SplitN(arg, "=", 2)[1] + return arg, val, true + } + } + + return "", "", false +} + +// keysForSearch returns whole key and a prefix to search for given key in [Args]. +// First one is the whole key to look for without `=` at the end for option-like arguments without any value. +// Second one is a prefix with `=` at the end for arguments with value. +// +// Arguments are normalized to `--key[=value]` in [ParseArgs], here this function also makes sure +// the returned prefixes have the same prefix format for the given key. +func (a *Args) keysForSearch(key string) (string, string) { + prefix := strings.TrimSuffix(key, "=") + + if !strings.HasPrefix(key, "-") { + prefix = "--" + prefix + } + + return prefix, prefix + "=" +} diff --git a/pkg/mountpoint/args_test.go b/pkg/mountpoint/args_test.go new file mode 100644 index 00000000..7af819c0 --- /dev/null +++ b/pkg/mountpoint/args_test.go @@ -0,0 +1,372 @@ +package mountpoint_test + +import ( + "testing" + + "github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint" + "github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert" +) + +func TestParsingMountpointArgs(t *testing.T) { + testCases := []struct { + name string + input []string + want []string + }{ + { + name: "no prefix", + input: []string{ + "allow-delete", + "region us-west-2", + "aws-max-attempts 5", + }, + want: []string{ + "--allow-delete", + "--aws-max-attempts=5", + "--region=us-west-2", + }, + }, + { + name: "with prefix", + input: []string{ + "--cache /tmp/s3-cache", + "--max-cache-size 500", + "--metadata-ttl 3", + }, + want: []string{ + "--cache=/tmp/s3-cache", + "--max-cache-size=500", + "--metadata-ttl=3", + }, + }, + { + name: "with equals but no prefix", + input: []string{ + "allow-delete", + "region=us-west-2", + "sse=aws:kms", + "sse-kms-key-id=arn:aws:kms:us-west-2:012345678900:key/00000000-0000-0000-0000-000000000000", + }, + want: []string{ + "--allow-delete", + "--region=us-west-2", + "--sse-kms-key-id=arn:aws:kms:us-west-2:012345678900:key/00000000-0000-0000-0000-000000000000", + "--sse=aws:kms", + }, + }, + { + name: "with equals and prefix", + input: []string{ + "--allow-other", + "--uid=1000", + "--gid=2000", + }, + want: []string{ + "--allow-other", + "--gid=2000", + "--uid=1000", + }, + }, + { + name: "with multiple spaces", + input: []string{ + "--allow-other", + "--uid 1000", + "--gid 2000", + }, + want: []string{ + "--allow-other", + "--gid=2000", + "--uid=1000", + }, + }, + { + name: "with single dash prefix", + input: []string{ + "-d", + "-l logs/", + }, + want: []string{ + "-d", + "-l=logs/", + }, + }, + { + name: "mixed prefix and equal signs", + input: []string{ + "--allow-delete", + "read-only", + "--cache=/tmp/s3-cache", + "--region us-east-1", + "prefix some-s3-prefix/", + "-d", + "-l=logs/", + }, + want: []string{ + "--allow-delete", + "--cache=/tmp/s3-cache", + "--prefix=some-s3-prefix/", + "--read-only", + "--region=us-east-1", + "-d", + "-l=logs/", + }, + }, + { + name: "with duplicated options", + input: []string{ + "--allow-other", + "--read-only", + "read-only", + "--allow-other", + }, + want: []string{ + "--allow-other", + "--read-only", + }, + }, + { + name: "with unsupported options", + input: []string{ + "--allow-other", + "--read-only", + "--foreground", "-f", + "--help", "-h", + "--version", "-v", + }, + want: []string{ + "--allow-other", + "--read-only", + }, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + args := mountpoint.ParseArgs(testCase.input) + assert.Equals(t, testCase.want, args.SortedList()) + }) + } +} + +func TestInsertingNewArgsToMountpointArgs(t *testing.T) { + testCases := []struct { + name string + existingArgs []string + newArg string + want []string + }{ + { + name: "new arg", + existingArgs: []string{ + "allow-delete", + "region us-west-2", + "aws-max-attempts 5", + }, + newArg: mountpoint.ArgReadOnly, + want: []string{ + "--allow-delete", + "--aws-max-attempts=5", + "--read-only", + "--region=us-west-2", + }, + }, + { + name: "existing arg", + existingArgs: []string{ + "allow-delete", + "read-only", + "region us-west-2", + "aws-max-attempts 5", + }, + newArg: mountpoint.ArgReadOnly, + want: []string{ + "--allow-delete", + "--aws-max-attempts=5", + "--read-only", + "--region=us-west-2", + }, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + args := mountpoint.ParseArgs(testCase.existingArgs) + args.Insert(testCase.newArg) + assert.Equals(t, testCase.want, args.SortedList()) + }) + } +} + +func TestExtractingAnArgumentsValueFromMountpointArgs(t *testing.T) { + testCases := []struct { + name string + args []string + argToExtract string + want string + exists bool + }{ + { + name: "existing argument", + args: []string{ + "cache /tmp/s3-cache", + "region us-west-2", + "aws-max-attempts 5", + }, + argToExtract: mountpoint.ArgCache, + want: "/tmp/s3-cache", + exists: true, + }, + { + name: "existing argument with equal", + args: []string{ + "cache /tmp/s3-cache", + "region=us-west-2", + "aws-max-attempts 5", + }, + argToExtract: mountpoint.ArgRegion, + want: "us-west-2", + exists: true, + }, + { + name: "existing argument queried without prefix", + args: []string{ + "cache /tmp/s3-cache", + "region=us-west-2", + "aws-max-attempts 5", + }, + argToExtract: "region", + want: "us-west-2", + exists: true, + }, + { + name: "existing argument queried without equals", + args: []string{ + "cache /tmp/s3-cache", + "region=us-west-2", + "aws-max-attempts 5", + }, + argToExtract: "--region", + want: "us-west-2", + exists: true, + }, + { + name: "non-existent argument", + args: []string{ + "cache /tmp/s3-cache", + "aws-max-attempts 5", + }, + argToExtract: mountpoint.ArgRegion, + want: "", + exists: false, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + args := mountpoint.ParseArgs(testCase.args) + gotValue, gotExists := args.Value(testCase.argToExtract) + assert.Equals(t, testCase.want, gotValue) + assert.Equals(t, testCase.exists, gotExists) + }) + } +} + +func TestRemovingAnArgumentFromMountpointArgs(t *testing.T) { + testCases := []struct { + name string + args []string + argToRemove string + want string + argsAfter []string + exists bool + }{ + { + name: "existing argument", + args: []string{ + "user-agent-prefix foo/bar", + "cache /tmp/s3-cache", + "region us-west-2", + "aws-max-attempts 5", + }, + argToRemove: mountpoint.ArgUserAgentPrefix, + want: "foo/bar", + exists: true, + argsAfter: []string{ + "--aws-max-attempts=5", + "--cache=/tmp/s3-cache", + "--region=us-west-2", + }, + }, + { + name: "existing argument with equal", + args: []string{ + "cache /tmp/s3-cache", + "region=us-west-2", + "aws-max-attempts=5", + }, + argToRemove: mountpoint.ArgAWSMaxAttempts, + want: "5", + exists: true, + argsAfter: []string{ + "--cache=/tmp/s3-cache", + "--region=us-west-2", + }, + }, + { + name: "non-existent argument", + args: []string{ + "cache /tmp/s3-cache", + "aws-max-attempts 5", + }, + argToRemove: mountpoint.ArgRegion, + want: "", + exists: false, + argsAfter: []string{ + "--aws-max-attempts=5", + "--cache=/tmp/s3-cache", + }, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + args := mountpoint.ParseArgs(testCase.args) + gotValue, gotExists := args.Remove(testCase.argToRemove) + assert.Equals(t, testCase.want, gotValue) + assert.Equals(t, testCase.exists, gotExists) + assert.Equals(t, testCase.argsAfter, args.SortedList()) + }) + } +} + +func TestQueryingExistenceOfAKeyInMountpointArgs(t *testing.T) { + args := mountpoint.ParseArgs([]string{ + "--allow-other", + "--cache /tmp/s3-cache", + "read-only", + }) + + assert.Equals(t, true, args.Has(mountpoint.ArgAllowOther)) + assert.Equals(t, true, args.Has(mountpoint.ArgCache)) + assert.Equals(t, true, args.Has(mountpoint.ArgReadOnly)) + assert.Equals(t, false, args.Has(mountpoint.ArgAllowRoot)) + assert.Equals(t, false, args.Has(mountpoint.ArgRegion)) +} + +func TestCreatingMountpointArgsFromAlreadyParsedArgs(t *testing.T) { + args := mountpoint.ParseArgs([]string{ + "--allow-other", + "--cache /tmp/s3-cache", + "read-only", + }) + args.Insert("--user-agent-prefix=s3-csi-driver/1.11.0 credential-source#pod k8s/v1.30.6-eks-7f9249a") + + want := []string{ + "--allow-other", + "--cache=/tmp/s3-cache", + "--read-only", + "--user-agent-prefix=s3-csi-driver/1.11.0 credential-source#pod k8s/v1.30.6-eks-7f9249a", + } + + assert.Equals(t, want, args.SortedList()) + + parsedArgs := mountpoint.ParsedArgs(args.SortedList()) + assert.Equals(t, want, parsedArgs.SortedList()) +} diff --git a/tests/sanity/sanity_test.go b/tests/sanity/sanity_test.go index 2ec9e8eb..1039861e 100644 --- a/tests/sanity/sanity_test.go +++ b/tests/sanity/sanity_test.go @@ -29,7 +29,8 @@ import ( "github.com/awslabs/aws-s3-csi-driver/pkg/driver" "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node" - "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider" + "github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/regionprovider" ) const ( @@ -72,8 +73,10 @@ var _ = BeforeSuite(func() { NodeID: "fake_id", NodeServer: node.NewS3NodeServer( "fake_id", - &mounter.FakeMounter{}, - mounter.NewCredentialProvider(nil, GinkgoT().TempDir(), mounter.RegionFromIMDSOnce), + nil, + credentialprovider.New(nil), + regionprovider.New(regionprovider.RegionFromIMDSOnce), + "v1.31.0", ), } go func() {