Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split the functionality in node/mounter into smaller packages #328

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
74 changes: 74 additions & 0 deletions pkg/driver/node/envprovider/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Package envprovider provides utilities for accessing environment variables to pass Mountpoint.
package envprovider

import (
"fmt"
"os"
"slices"
"strings"
)

const (
EnvRegion = "AWS_REGION"
EnvDefaultRegion = "AWS_DEFAULT_REGION"
EnvSTSRegionalEndpoints = "AWS_STS_REGIONAL_ENDPOINTS"
EnvMaxAttempts = "AWS_MAX_ATTEMPTS"
EnvProfile = "AWS_PROFILE"
EnvConfigFile = "AWS_CONFIG_FILE"
EnvSharedCredentialsFile = "AWS_SHARED_CREDENTIALS_FILE"
EnvRoleARN = "AWS_ROLE_ARN"
EnvWebIdentityTokenFile = "AWS_WEB_IDENTITY_TOKEN_FILE"
EnvEC2MetadataDisabled = "AWS_EC2_METADATA_DISABLED"

EnvMountpointCacheKey = "UNSTABLE_MOUNTPOINT_CACHE_KEY"
)

// An Environment represents a list of environment variables.
type Environment = []string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be a dict or list of tuples internally

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think this should be map[string]string. Making it a map would cause lots of diffs in the tests though, they were already []string and due to way type aliases work in Go, we can just use []string in place of Environment. But if we make it a map, then we would need to update all call sites to use a proper type, which then would make this PR even bigger. I think it's best to do that as a follow-up PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me


// envAllowlist is the list of environment variables to pass-by by default.
// If any of these set, it will be returned as-is in [Provide].
var envAllowlist = []string{
EnvRegion,
EnvDefaultRegion,
EnvSTSRegionalEndpoints,
}

// Region returns detected region from environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`.
// It returns an empty string if both is unset.
func Region() string {
region := os.Getenv(EnvRegion)
if region != "" {
return region
}
return os.Getenv(EnvDefaultRegion)
}

// Provide returns list of environment variables to pass Mountpoint.
func Provide() Environment {
environment := Environment{}
for _, key := range envAllowlist {
val := os.Getenv(key)
if val != "" {
environment = append(environment, Format(key, val))
}
}
return environment
}

// Format formats given key and value to be used as an environment variable.
func Format(key, value string) string {
return fmt.Sprintf("%s=%s", key, value)
}

// Remove removes environment variable with given `key` from given environment variables `env`.
// It returns updated environment variables.
func Remove(env Environment, key string) Environment {
prefix := key
if !strings.HasSuffix(key, "=") {
prefix = prefix + "="
}
return slices.DeleteFunc(env, func(k string) bool {
return strings.HasPrefix(k, prefix)
})
}
149 changes: 149 additions & 0 deletions pkg/driver/node/envprovider/provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package envprovider_test

import (
"testing"

"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider"
"github.com/awslabs/aws-s3-csi-driver/pkg/util/testutil/assert"
)

func TestGettingRegion(t *testing.T) {
testCases := []struct {
name string
envRegion string
envDefaultRegion string
want string
}{
{
name: "both region envs are set",
envRegion: "us-west-1",
envDefaultRegion: "us-east-1",
want: "us-west-1",
},
{
name: "only default region env is set",
envRegion: "",
envDefaultRegion: "us-east-1",
want: "us-east-1",
},
{
name: "no region env is set",
envRegion: "",
envDefaultRegion: "",
want: "",
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("AWS_REGION", testCase.envRegion)
t.Setenv("AWS_DEFAULT_REGION", testCase.envDefaultRegion)
assert.Equals(t, testCase.want, envprovider.Region())
})
}
}

func TestProvidingEnvironmentVariables(t *testing.T) {
testCases := []struct {
name string
env map[string]string
want []string
}{
{
name: "no env vars set",
env: map[string]string{},
want: []string{},
},
{
name: "some allowed env vars set",
env: map[string]string{
"AWS_REGION": "us-west-1",
"AWS_DEFAULT_REGION": "us-east-1",
"AWS_STS_REGIONAL_ENDPOINTS": "regional",
"AWS_MAX_ATTEMPTS": "10",
},
want: []string{
"AWS_REGION=us-west-1",
"AWS_DEFAULT_REGION=us-east-1",
"AWS_STS_REGIONAL_ENDPOINTS=regional",
},
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
for k, v := range testCase.env {
t.Setenv(k, v)
}
assert.Equals(t, testCase.want, envprovider.Provide())
})
}
}

func TestFormattingEnvironmentVariable(t *testing.T) {
testCases := []struct {
name string
key string
value string
want string
}{
{
name: "region",
key: "AWS_REGION",
value: "us-west-1",
want: "AWS_REGION=us-west-1",
},
{
name: "role arn",
key: "AWS_ROLE_ARN",
value: "arn:aws:iam::account:role/csi-driver-role-name",
want: "AWS_ROLE_ARN=arn:aws:iam::account:role/csi-driver-role-name",
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
assert.Equals(t, testCase.want, envprovider.Format(testCase.key, testCase.value))
})
}
}

func TestRemovingAKeyFromListOfEnvironmentVariables(t *testing.T) {
testCases := []struct {
name string
env envprovider.Environment
key string
want envprovider.Environment
}{
{
name: "empty environment",
env: envprovider.Environment{},
key: "AWS_REGION",
want: envprovider.Environment{},
},
{
name: "remove existing key",
env: envprovider.Environment{"AWS_REGION=us-west-1", "AWS_DEFAULT_REGION=us-east-1"},
key: "AWS_REGION",
want: envprovider.Environment{"AWS_DEFAULT_REGION=us-east-1"},
},
{
name: "remove existing key with equals sign",
env: envprovider.Environment{"AWS_REGION=us-west-1", "AWS_DEFAULT_REGION=us-east-1"},
key: "AWS_REGION=",
want: envprovider.Environment{"AWS_DEFAULT_REGION=us-east-1"},
},
{
name: "remove non-existing key",
env: envprovider.Environment{"AWS_REGION=us-west-1", "AWS_DEFAULT_REGION=us-east-1"},
key: "AWS_MAX_ATTEMPTS",
want: envprovider.Environment{"AWS_REGION=us-west-1", "AWS_DEFAULT_REGION=us-east-1"},
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
assert.Equals(t, testCase.want, envprovider.Remove(testCase.env, testCase.key))
})
}
}
34 changes: 34 additions & 0 deletions pkg/driver/node/regionprovider/imds.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package regionprovider
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has any code here changed? The removed code isn't in this commit

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no code change (other than function name) in this package (regionprovider/{imds, provider, provider_test}.go). It just moved out of now removed pkg/driver/node/mounter/credential_provider.go


import (
"context"
"fmt"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"k8s.io/klog/v2"
)

// RegionFromIMDSOnce tries to detect AWS region by making a request to IMDS.
// It only makes request to IMDS once and caches the value.
var RegionFromIMDSOnce = sync.OnceValues(func() (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
klog.V(5).Infof("regionprovider: Failed to create config for IMDS client: %v", err)
return "", fmt.Errorf("could not create config for imds client: %w", err)
}

client := imds.NewFromConfig(cfg)
output, err := client.GetRegion(ctx, &imds.GetRegionInput{})
if err != nil {
klog.V(5).Infof("regionprovider: Failed to get region from IMDS: %v", err)
return "", fmt.Errorf("failed to get region from imds: %w", err)
}

return output.Region, nil
})
69 changes: 69 additions & 0 deletions pkg/driver/node/regionprovider/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Package regionprovider provides utilities for detecting region by
// looking environment variables, mount options, or calling IMDS.
package regionprovider

import (
"errors"

"k8s.io/klog/v2"

"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/envprovider"
"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext"
"github.com/awslabs/aws-s3-csi-driver/pkg/mountpoint"
)

// ErrUnknownRegion is the error returned when the region could not be detected.
var ErrUnknownRegion = errors.New("regionprovider: unknown region")

// A Provider provides methods for detecting regions.
type Provider struct {
regionFromIMDS func() (string, error)
}

// New creates a new [Provider] by using given [regionFromIMDS].
func New(regionFromIMDS func() (string, error)) *Provider {
// `regionFromIMDS` is a `sync.OnceValues` and it only makes request to IMDS once,
// this call is basically here to pre-warm the cache of IMDS call.
go func() {
_, _ = regionFromIMDS()
}()

return &Provider{regionFromIMDS: regionFromIMDS}
}

// SecurityTokenService tries to detect AWS region to use for STS.
//
// It looks for the following (in-order):
// 1. `stsRegion` passed via volume context
// 2. Region set for S3 bucket via mount options
// 3. `AWS_REGION` or `AWS_DEFAULT_REGION` env variables
// 4. Calling IMDS to detect region
//
// It returns [ErrUnknownRegion] if all of them fails.
func (p *Provider) SecurityTokenService(volumeContext map[string]string, args mountpoint.Args) (string, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name here doesn't describe the functionality. From what I understand, it either returns the region to use for STS or returns an error. Perhaps getSTSRegion?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the usage of this function is like: regionProvider.SecurityTokenService(volumeCtx, args). That's why I didn't add region to name as we'd duplicate the information, and not sure if get adds any value.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think get adds value as otherwise it's not obvious what we're actually doing with STS. I'd be happy with regionProvider.getSTS(...)

region := volumeContext[volumecontext.STSRegion]
if region != "" {
klog.V(5).Infof("regionprovider: Detected STS region %s from volume context", region)
return region, nil
}

if region, ok := args.Value(mountpoint.ArgRegion); ok {
klog.V(5).Infof("regionprovider: Detected STS region %s from S3 bucket region", region)
return region, nil
}

region = envprovider.Region()
if region != "" {
klog.V(5).Infof("regionprovider: Detected STS region %s from env variable", region)
return region, nil
}

// We're ignoring the error here, makes a call to IMDS only once and logs the error in case of error
region, _ = p.regionFromIMDS()
if region != "" {
klog.V(5).Infof("regionprovider: Detected STS region %s from IMDS", region)
return region, nil
}

return "", ErrUnknownRegion
}
Loading