Skip to content

Commit

Permalink
Add volumecontext package for accessing volume context from CSI (#340)
Browse files Browse the repository at this point in the history
Splitted out of
#328.

This might not make sense since it's moved out of its original context,
but the original motivation was that these constants was spread out to
different packages and that was causing circular dependencies.

For example, if you need to use `VolumeCtxBucketName` in
`pkg/driver/node/mounter` package, you'd need to import
`pkg/driver/node` package which would cause a circular dependency as
`pkg/driver/node` imports `pkg/driver/node/mounter`. In situations like
that, it's best to extract common things into a leaf package that
doesn't import anything, which is what this PR does with
`pkg/driver/node/volumecontext` package.

It also makes finding these volume context keys easier by placing them
into a singular place.

---

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Signed-off-by: Burak Varlı <[email protected]>
  • Loading branch information
unexge authored Jan 14, 2025
1 parent 5c67b96 commit 8c65618
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 33 deletions.
43 changes: 18 additions & 25 deletions pkg/driver/node/mounter/credential_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
k8sv1 "k8s.io/client-go/kubernetes/typed/core/v1"
"k8s.io/klog/v2"
k8sstrings "k8s.io/utils/strings"

"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext"
)

const hostPluginDirEnv = "HOST_PLUGIN_DIR"
Expand All @@ -34,15 +36,6 @@ const (
AuthenticationSourcePod AuthenticationSource = "pod"
)

const (
VolumeCtxAuthenticationSource = "authenticationSource"
VolumeCtxSTSRegion = "stsRegion"
VolumeCtxServiceAccountName = "csi.storage.k8s.io/serviceAccount.name"
VolumeCtxServiceAccountTokens = "csi.storage.k8s.io/serviceAccount.tokens"
VolumeCtxPodNamespace = "csi.storage.k8s.io/pod.namespace"
VolumeCtxPodUID = "csi.storage.k8s.io/pod.uid"
)

const (
// This is to ensure only owner/group can read the file and no one else.
serviceAccountTokenPerm = 0440
Expand Down Expand Up @@ -91,15 +84,15 @@ func (c *CredentialProvider) CleanupToken(volumeID string, podID string) error {

// Provide provides mount credentials for given volume and volume context.
// Depending on the configuration, it either returns driver-level or pod-level credentials.
func (c *CredentialProvider) Provide(ctx context.Context, volumeID string, volumeContext map[string]string, mountpointArgs []string) (*MountCredentials, error) {
if volumeContext == nil {
func (c *CredentialProvider) Provide(ctx context.Context, volumeID string, volumeCtx map[string]string, mountpointArgs []string) (*MountCredentials, error) {
if volumeCtx == nil {
return nil, status.Error(codes.InvalidArgument, "Missing volume context")
}

authenticationSource := volumeContext[VolumeCtxAuthenticationSource]
authenticationSource := volumeCtx[volumecontext.AuthenticationSource]
switch authenticationSource {
case AuthenticationSourcePod:
return c.provideFromPod(ctx, volumeID, volumeContext, mountpointArgs)
return c.provideFromPod(ctx, volumeID, volumeCtx, mountpointArgs)
case AuthenticationSourceUnspecified, AuthenticationSourceDriver:
return c.provideFromDriver()
default:
Expand All @@ -126,10 +119,10 @@ func (c *CredentialProvider) provideFromDriver() (*MountCredentials, error) {
}, nil
}

func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string, volumeContext map[string]string, mountpointArgs []string) (*MountCredentials, error) {
func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string, volumeCtx map[string]string, mountpointArgs []string) (*MountCredentials, error) {
klog.V(4).Infof("NodePublishVolume: Using pod identity")

tokensJson := volumeContext[VolumeCtxServiceAccountTokens]
tokensJson := volumeCtx[volumecontext.CSIServiceAccountTokens]
if tokensJson == "" {
klog.Error("`authenticationSource` configured to `pod` but no service account tokens are received. Please make sure to enable `podInfoOnMountCompat`, see " + podLevelCredentialsDocsPage)
return nil, status.Error(codes.InvalidArgument, "Missing service account tokens")
Expand All @@ -146,12 +139,12 @@ func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string
return nil, status.Errorf(codes.InvalidArgument, "Missing service account token for %s", serviceAccountTokenAudienceSTS)
}

awsRoleARN, err := c.findPodServiceAccountRole(ctx, volumeContext)
awsRoleARN, err := c.findPodServiceAccountRole(ctx, volumeCtx)
if err != nil {
return nil, err
}

region, err := c.stsRegion(volumeContext, mountpointArgs)
region, err := c.stsRegion(volumeCtx, mountpointArgs)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "Failed to detect STS AWS Region, please explicitly set the AWS Region, see "+stsConfigDocsPage)
}
Expand All @@ -161,7 +154,7 @@ func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string
defaultRegion = region
}

podID := volumeContext[VolumeCtxPodUID]
podID := volumeCtx[volumecontext.CSIPodUID]
if podID == "" {
return nil, status.Error(codes.InvalidArgument, "Missing Pod info. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage)
}
Expand All @@ -174,8 +167,8 @@ func (c *CredentialProvider) provideFromPod(ctx context.Context, volumeID string
hostPluginDir := hostPluginDirWithDefault()
hostTokenPath := path.Join(hostPluginDir, c.tokenFilename(podID, volumeID))

podNamespace := volumeContext[VolumeCtxPodNamespace]
podServiceAccount := volumeContext[VolumeCtxServiceAccountName]
podNamespace := volumeCtx[volumecontext.CSIPodNamespace]
podServiceAccount := volumeCtx[volumecontext.CSIServiceAccountName]
cacheKey := podNamespace + "/" + podServiceAccount

return &MountCredentials{
Expand Down Expand Up @@ -221,9 +214,9 @@ func (c *CredentialProvider) tokenFilename(podID string, volumeID string) string
return filename.String()
}

func (c *CredentialProvider) findPodServiceAccountRole(ctx context.Context, volumeContext map[string]string) (string, error) {
podNamespace := volumeContext[VolumeCtxPodNamespace]
podServiceAccount := volumeContext[VolumeCtxServiceAccountName]
func (c *CredentialProvider) findPodServiceAccountRole(ctx context.Context, volumeCtx map[string]string) (string, error) {
podNamespace := volumeCtx[volumecontext.CSIPodNamespace]
podServiceAccount := volumeCtx[volumecontext.CSIServiceAccountName]
if podNamespace == "" || podServiceAccount == "" {
klog.Error("`authenticationSource` configured to `pod` but no pod info found. Please make sure to enable `podInfoOnMountCompat`, see " + podLevelCredentialsDocsPage)
return "", status.Error(codes.InvalidArgument, "Missing Pod info. Please make sure to enable `podInfoOnMountCompat`, see "+podLevelCredentialsDocsPage)
Expand Down Expand Up @@ -252,8 +245,8 @@ func (c *CredentialProvider) findPodServiceAccountRole(ctx context.Context, volu
// 4. Calling IMDS to detect region
//
// It returns an error if all of them fails.
func (c *CredentialProvider) stsRegion(volumeContext map[string]string, mountpointArgs []string) (string, error) {
region := volumeContext[VolumeCtxSTSRegion]
func (c *CredentialProvider) stsRegion(volumeCtx map[string]string, mountpointArgs []string) (string, error) {
region := volumeCtx[volumecontext.STSRegion]
if region != "" {
klog.V(5).Infof("NodePublishVolume: Pod-level: Detected STS region %s from volume context", region)
return region, nil
Expand Down
16 changes: 8 additions & 8 deletions pkg/driver/node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ import (

"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/mounter"
"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/targetpath"
"github.com/awslabs/aws-s3-csi-driver/pkg/driver/node/volumecontext"
)

const (
volumeCtxBucketName = "bucketName"
defaultKubeletPath = "/var/lib/kubelet"
defaultKubeletPath = "/var/lib/kubelet"
)

var kubeletPath = getKubeletPath()
Expand Down Expand Up @@ -67,9 +67,9 @@ func NewS3NodeServer(nodeID string, mounter mounter.Mounter, credentialProvider
}

func (ns *S3NodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) {
volumeContext := req.GetVolumeContext()
if volumeContext[mounter.VolumeCtxAuthenticationSource] == mounter.AuthenticationSourcePod {
podID := volumeContext[mounter.VolumeCtxPodUID]
volumeCtx := req.GetVolumeContext()
if volumeCtx[volumecontext.AuthenticationSource] == mounter.AuthenticationSourcePod {
podID := volumeCtx[volumecontext.CSIPodUID]
volumeID := req.GetVolumeId()
if podID != "" && volumeID != "" {
err := ns.credentialProvider.CleanupToken(volumeID, podID)
Expand All @@ -94,9 +94,9 @@ func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePubl
return nil, status.Error(codes.InvalidArgument, "Volume ID not provided")
}

volumeContext := req.GetVolumeContext()
volumeCtx := req.GetVolumeContext()

bucket, ok := volumeContext[volumeCtxBucketName]
bucket, ok := volumeCtx[volumecontext.BucketName]
if !ok {
return nil, status.Error(codes.InvalidArgument, "Bucket name not provided")
}
Expand Down Expand Up @@ -295,7 +295,7 @@ func (ns *S3NodeServer) isValidVolumeCapabilities(volCaps []*csi.VolumeCapabilit
// with sensitive fields removed.
func logSafeNodePublishVolumeRequest(req *csi.NodePublishVolumeRequest) *csi.NodePublishVolumeRequest {
safeVolumeContext := maps.Clone(req.VolumeContext)
delete(safeVolumeContext, mounter.VolumeCtxServiceAccountTokens)
delete(safeVolumeContext, volumecontext.CSIServiceAccountTokens)

return &csi.NodePublishVolumeRequest{
VolumeId: req.VolumeId,
Expand Down
13 changes: 13 additions & 0 deletions pkg/driver/node/volumecontext/volume_context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Package volumecontext provides utilities for accessing volume context passed via CSI RPC.
package volumecontext

const (
BucketName = "bucketName"
AuthenticationSource = "authenticationSource"
STSRegion = "stsRegion"

CSIServiceAccountName = "csi.storage.k8s.io/serviceAccount.name"
CSIServiceAccountTokens = "csi.storage.k8s.io/serviceAccount.tokens"
CSIPodNamespace = "csi.storage.k8s.io/pod.namespace"
CSIPodUID = "csi.storage.k8s.io/pod.uid"
)

0 comments on commit 8c65618

Please sign in to comment.