Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split the functionality in node/mounter into smaller packages #328

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
45 changes: 11 additions & 34 deletions pkg/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,22 @@ import (
"fmt"
"net"
"os"
"time"

"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node"
"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter"
"github.com/awslabs/aws-s3-csi-driver/pkg/driver/version"
"github.com/awslabs/aws-s3-csi-driver/pkg/util"
"github.com/container-storage-interface/spec/lib/go/csi"
"google.golang.org/grpc"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/klog/v2"

"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node"
"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider"
"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter"
"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/regionprovider"
"github.com/awslabs/aws-s3-csi-driver/pkg/driver/version"
)

const (
driverName = "s3.csi.aws.com"
webIdentityTokenEnv = "AWS_WEB_IDENTITY_TOKEN_FILE"
driverName = "s3.csi.aws.com"

grpcServerMaxReceiveMessageSize = 1024 * 1024 * 2 // 2MB

Expand Down Expand Up @@ -74,13 +74,14 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error
klog.Infof("Driver version: %v, Git commit: %v, build date: %v, nodeID: %v, mount-s3 version: %v, kubernetes version: %v",
version.DriverVersion, version.GitCommit, version.BuildDate, nodeID, mpVersion, kubernetesVersion)

systemd_mounter, err := mounter.NewSystemdMounter(mpVersion, kubernetesVersion)
systemdMounter, err := mounter.NewSystemdMounter(mpVersion)
if err != nil {
klog.Fatalln(err)
}

credentialProvider := mounter.NewCredentialProvider(clientset.CoreV1(), containerPluginDir, mounter.RegionFromIMDSOnce)
nodeServer := node.NewS3NodeServer(nodeID, systemd_mounter, credentialProvider)
credentialProvider := credentialprovider.New(clientset.CoreV1())
regionProvider := regionprovider.New(regionprovider.RegionFromIMDSOnce)
nodeServer := node.NewS3NodeServer(nodeID, systemdMounter, credentialProvider, regionProvider, kubernetesVersion)

return &Driver{
Endpoint: endpoint,
Expand All @@ -90,14 +91,6 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) (*Driver, error
}

func (d *Driver) Run() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tokenFile := os.Getenv(webIdentityTokenEnv)
if tokenFile != "" {
klog.Infof("Found AWS_WEB_IDENTITY_TOKEN_FILE, syncing token")
go tokenFileTender(ctx, tokenFile, "/csi/token")
}
Comment on lines -93 to -99
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic has been moved to credentials_sts_web_identity.go. Since we enabled requiresRepublish as part of Pod-level identity support, Kubernetes will call NodePublishVolume periodically to update existing service account tokens before they expire, and the credential provider will be called as part of this method. Another reason for this change is that, this assumes a single location for service account tokens for Driver-level identity, but with containerization this won't be the case. See the note regarding credential provider in the original PR description for more context.


scheme, addr, err := ParseEndpoint(d.Endpoint)
if err != nil {
return err
Expand Down Expand Up @@ -150,22 +143,6 @@ func (d *Driver) Stop() {
d.Srv.Stop()
}

func tokenFileTender(ctx context.Context, sourcePath string, destPath string) {
for {
timer := time.After(10 * time.Second)
err := util.ReplaceFile(destPath, sourcePath, 0600)
if err != nil {
klog.Infof("Failed to sync AWS web token file: %v", err)
}
select {
case <-timer:
continue
case <-ctx.Done():
return
}
}
}

func kubernetesVersion(clientset *kubernetes.Clientset) (string, error) {
version, err := clientset.ServerVersion()
if err != nil {
Expand Down
112 changes: 0 additions & 112 deletions pkg/driver/node/awsprofile/aws_profile_test.go

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,50 +9,51 @@ import (
"path/filepath"
"strings"
"unicode"

"github.com/google/renameio"
)

const (
awsProfileName = "s3-csi"
awsProfileConfigFilename = "s3-csi-config"
awsProfileCredentialsFilename = "s3-csi-credentials"
awsProfileFilePerm = fs.FileMode(0400) // only owner readable
)

// ErrInvalidCredentials is returned when given AWS Credentials contains invalid characters.
var ErrInvalidCredentials = errors.New("aws-profile: Invalid AWS Credentials")

// An AWSProfile represents an AWS profile with it's credentials and config files.
type AWSProfile struct {
Name string
ConfigPath string
CredentialsPath string
Name string
ConfigFilename string
CredentialsFilename string
}

// CreateAWSProfile creates an AWS Profile with credentials and config files from given credentials.
// Created credentials and config files can be clean up with `CleanupAWSProfile`.
func CreateAWSProfile(basepath string, accessKeyID string, secretAccessKey string, sessionToken string) (AWSProfile, error) {
func CreateAWSProfile(basepath string, accessKeyID string, secretAccessKey string, sessionToken string, filePerm fs.FileMode) (AWSProfile, error) {
if !isValidCredential(accessKeyID) || !isValidCredential(secretAccessKey) || !isValidCredential(sessionToken) {
return AWSProfile{}, ErrInvalidCredentials
}

name := awsProfileName

configPath := filepath.Join(basepath, awsProfileConfigFilename)
err := writeAWSProfileFile(configPath, configFileContents(name))
err := writeAWSProfileFile(configPath, configFileContents(name), filePerm)
if err != nil {
return AWSProfile{}, fmt.Errorf("aws-profile: Failed to create config file %s: %v", configPath, err)
}

credentialsPath := filepath.Join(basepath, awsProfileCredentialsFilename)
err = writeAWSProfileFile(credentialsPath, credentialsFileContents(name, accessKeyID, secretAccessKey, sessionToken))
err = writeAWSProfileFile(credentialsPath, credentialsFileContents(name, accessKeyID, secretAccessKey, sessionToken), filePerm)
if err != nil {
return AWSProfile{}, fmt.Errorf("aws-profile: Failed to create credentials file %s: %v", credentialsPath, err)
}

return AWSProfile{
Name: name,
ConfigPath: configPath,
CredentialsPath: credentialsPath,
Name: name,
ConfigFilename: awsProfileConfigFilename,
CredentialsFilename: awsProfileCredentialsFilename,
}, nil
}

Expand All @@ -75,14 +76,8 @@ func CleanupAWSProfile(basepath string) error {
return nil
}

func writeAWSProfileFile(path string, content string) error {
err := os.WriteFile(path, []byte(content), awsProfileFilePerm)
if err != nil {
return err
}
// If the given file exists, `os.WriteFile` just truncates it without changing it's permissions,
// so we need to ensure it always has the correct permissions.
return os.Chmod(path, awsProfileFilePerm)
Comment on lines -83 to -85
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

func writeAWSProfileFile(path string, content string, filePerm os.FileMode) error {
return renameio.WriteFile(path, []byte(content), filePerm)
}

func credentialsFileContents(profile string, accessKeyID string, secretAccessKey string, sessionToken string) string {
Expand Down
100 changes: 100 additions & 0 deletions pkg/driver/node/credentialprovider/awsprofile/aws_profile_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package awsprofile_test

import (
"errors"
"io/fs"
"os"
"path/filepath"
"testing"

"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider/awsprofile"
"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/credentialprovider/awsprofile/awsprofiletest"
"github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert"
)

const testAccessKeyId = "test-access-key-id"
const testSecretAccessKey = "test-secret-access-key"
const testSessionToken = "test-session-token"
const testFilePerm = fs.FileMode(0600)

func TestCreatingAWSProfile(t *testing.T) {
t.Run("create config and credentials files", func(t *testing.T) {
basepath := t.TempDir()
profile, err := awsprofile.CreateAWSProfile(basepath, testAccessKeyId, testSecretAccessKey, testSessionToken, testFilePerm)
assert.NoError(t, err)
assertCredentialsFromAWSProfile(t, basepath, profile, testAccessKeyId, testSecretAccessKey, testSessionToken)
})

t.Run("create config and credentials files with empty session token", func(t *testing.T) {
basepath := t.TempDir()
profile, err := awsprofile.CreateAWSProfile(basepath, testAccessKeyId, testSecretAccessKey, "", testFilePerm)
assert.NoError(t, err)
assertCredentialsFromAWSProfile(t, basepath, profile, testAccessKeyId, testSecretAccessKey, "")
})

t.Run("ensure config and credentials files are created with correct permissions", func(t *testing.T) {
basepath := t.TempDir()
profile, err := awsprofile.CreateAWSProfile(basepath, testAccessKeyId, testSecretAccessKey, testSessionToken, testFilePerm)
assert.NoError(t, err)
assertCredentialsFromAWSProfile(t, basepath, profile, testAccessKeyId, testSecretAccessKey, testSessionToken)

configStat, err := os.Stat(filepath.Join(basepath, profile.ConfigFilename))
assert.NoError(t, err)
assert.Equals(t, testFilePerm, configStat.Mode())

credentialsStat, err := os.Stat(filepath.Join(basepath, profile.CredentialsFilename))
assert.NoError(t, err)
assert.Equals(t, testFilePerm, credentialsStat.Mode())
})

t.Run("fail if credentials contains non-ascii characters", func(t *testing.T) {
t.Run("access key ID", func(t *testing.T) {
_, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId+"\n\t\r credential_process=exit", testSecretAccessKey, testSessionToken, testFilePerm)
assert.Equals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials))
})
t.Run("secret access key", func(t *testing.T) {
_, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId, testSecretAccessKey+"\n", testSessionToken, testFilePerm)
assert.Equals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials))
})
t.Run("session token", func(t *testing.T) {
_, err := awsprofile.CreateAWSProfile(t.TempDir(), testAccessKeyId, testSecretAccessKey, testSessionToken+"\n\r", testFilePerm)
assert.Equals(t, true, errors.Is(err, awsprofile.ErrInvalidCredentials))
})
})
}

func TestCleaningUpAWSProfile(t *testing.T) {
t.Run("clean config and credentials files", func(t *testing.T) {
basepath := t.TempDir()

profile, err := awsprofile.CreateAWSProfile(basepath, testAccessKeyId, testSecretAccessKey, testSessionToken, testFilePerm)
assert.NoError(t, err)
assertCredentialsFromAWSProfile(t, basepath, profile, testAccessKeyId, testSecretAccessKey, testSessionToken)

err = awsprofile.CleanupAWSProfile(basepath)
assert.NoError(t, err)

_, err = os.Stat(filepath.Join(basepath, profile.ConfigFilename))
assert.Equals(t, true, errors.Is(err, fs.ErrNotExist))

_, err = os.Stat(filepath.Join(basepath, profile.CredentialsFilename))
assert.Equals(t, true, errors.Is(err, fs.ErrNotExist))
})

t.Run("cleaning non-existent config and credentials files should not be an error", func(t *testing.T) {
err := awsprofile.CleanupAWSProfile(t.TempDir())
assert.NoError(t, err)
})
}

func assertCredentialsFromAWSProfile(t *testing.T, basepath string, profile awsprofile.AWSProfile, accessKeyID string, secretAccessKey string, sessionToken string) {
awsprofiletest.AssertCredentialsFromAWSProfile(
t,
profile.Name,
filepath.Join(basepath, profile.ConfigFilename),
filepath.Join(basepath, profile.CredentialsFilename),
accessKeyID,
secretAccessKey,
sessionToken,
)
}
Loading
Loading