Skip to content

Commit

Permalink
Refactor AWS identity functions to aws_helper package
Browse files Browse the repository at this point in the history
  • Loading branch information
bwhaley committed Dec 18, 2019
1 parent 9ec6e4e commit a4e647a
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 46 deletions.
52 changes: 51 additions & 1 deletion aws_helper/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ package aws_helper

import (
"fmt"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/gruntwork-io/terragrunt/errors"
"github.com/gruntwork-io/terragrunt/options"
"time"
)

// A representation of the configuration options for an AWS Session
Expand Down Expand Up @@ -102,3 +103,52 @@ func AssumeIamRole(iamRoleArn string) (*sts.Credentials, error) {

return output.Credentials, nil
}

// Return the AWS caller identity associated with the current set of credentials
func GetAWSCallerIdentity(terragruntOptions *options.TerragruntOptions) (sts.GetCallerIdentityOutput, error) {
sess, err := session.NewSession()
if err != nil {
return sts.GetCallerIdentityOutput{}, errors.WithStackTrace(err)
}

if terragruntOptions.IamRole != "" {
sess.Config.Credentials = stscreds.NewCredentials(sess, terragruntOptions.IamRole)
}

identity, err := sts.New(sess).GetCallerIdentity(nil)
if err != nil {
return sts.GetCallerIdentityOutput{}, errors.WithStackTrace(err)
}

return *identity, nil
}

// Get the AWS account ID of the current session configuration
func GetAWSAccountID(terragruntOptions *options.TerragruntOptions) (string, error) {
identity, err := GetAWSCallerIdentity(terragruntOptions)
if err != nil {
return "", errors.WithStackTrace(err)
}

return *identity.Account, nil
}

// Get the ARN of the AWS identity associated with the current set of credentials
func GetAWSIdentityArn(terragruntOptions *options.TerragruntOptions) (string, error) {
identity, err := GetAWSCallerIdentity(terragruntOptions)
if err != nil {
return "", errors.WithStackTrace(err)
}

return *identity.Arn, nil
}

// Get the AWS user ID of the current session configuration
func GetAWSUserID(terragruntOptions *options.TerragruntOptions) (string, error) {
identity, err := GetAWSCallerIdentity(terragruntOptions)
if err != nil {
return "", errors.WithStackTrace(err)
}

return *identity.UserId, nil
}
35 changes: 7 additions & 28 deletions config/config_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@ import (
"fmt"
"path/filepath"

"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/hcl2/hcl"
tflang "github.com/hashicorp/terraform/lang"
"github.com/zclconf/go-cty/cty"
"github.com/zclconf/go-cty/cty/function"

"github.com/gruntwork-io/terragrunt/aws_helper"
"github.com/gruntwork-io/terragrunt/errors"
"github.com/gruntwork-io/terragrunt/options"
"github.com/gruntwork-io/terragrunt/shell"
Expand Down Expand Up @@ -302,48 +300,29 @@ func pathRelativeFromInclude(include *IncludeConfig, terragruntOptions *options.
return util.GetPathRelativeTo(includePath, currentPath)
}

// Return the AWS caller identity associated with the current set of credentials
func getAWSCallerID(include *IncludeConfig, terragruntOptions *options.TerragruntOptions) (sts.GetCallerIdentityOutput, error) {
sess, err := session.NewSession()
if err != nil {
return sts.GetCallerIdentityOutput{}, errors.WithStackTrace(err)
}

if terragruntOptions.IamRole != "" {
sess.Config.Credentials = stscreds.NewCredentials(sess, terragruntOptions.IamRole)
}

identity, err := sts.New(sess).GetCallerIdentity(nil)
if err != nil {
return sts.GetCallerIdentityOutput{}, errors.WithStackTrace(err)
}

return *identity, nil
}

// Return the AWS account id associated to the current set of credentials
func getAWSAccountID(include *IncludeConfig, terragruntOptions *options.TerragruntOptions) (string, error) {
identity, err := getAWSCallerID(include, terragruntOptions)
accountID, err := aws_helper.GetAWSAccountID(terragruntOptions)
if err == nil {
return *identity.Account, nil
return accountID, nil
}
return "", err
}

// Return the ARN of the AWS identity associated with the current set of credentials
func getAWSCallerIdentityARN(include *IncludeConfig, terragruntOptions *options.TerragruntOptions) (string, error) {
identity, err := getAWSCallerID(include, terragruntOptions)
identityARN, err := aws_helper.GetAWSIdentityArn(terragruntOptions)
if err == nil {
return *identity.Arn, nil
return identityARN, nil
}
return "", err
}

// Return the UserID of the AWS identity associated with the current set of credentials
func getAWSCallerIdentityUserID(include *IncludeConfig, terragruntOptions *options.TerragruntOptions) (string, error) {
identity, err := getAWSCallerID(include, terragruntOptions)
userID, err := aws_helper.GetAWSUserID(terragruntOptions)
if err == nil {
return *identity.UserId, nil
return userID, nil
}
return "", err
}
Expand Down
20 changes: 3 additions & 17 deletions remote/remote_state_s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/gruntwork-io/terragrunt/aws_helper"
"github.com/gruntwork-io/terragrunt/dynamodb"
"github.com/gruntwork-io/terragrunt/errors"
Expand Down Expand Up @@ -468,24 +467,11 @@ func isBucketAlreadyOwnedByYourError(err error) bool {
return isAwsErr && (awsErr.Code() == "BucketAlreadyOwnedByYou" || awsErr.Code() == "OperationAborted")
}

// Get the AWS account ID of the current session configuration
func getAWSAccountID(config *aws_helper.AwsSessionConfig, terragruntOptions *options.TerragruntOptions) (string, error) {
session, err := aws_helper.CreateAwsSession(config, terragruntOptions)
if err != nil {
return "", err
}

identity, err := sts.New(session).GetCallerIdentity(nil)
if err != nil {
return "", errors.WithStackTrace(err)
}

return *identity.Account, nil
}

// Create the S3 bucket specified in the given config
func EnableRootAccesstoS3Bucket(s3Client *s3.S3, config *RemoteStateConfigS3, terragruntOptions *options.TerragruntOptions) error {
accountID, err := getAWSAccountID(config.GetAwsSessionConfig(), terragruntOptions)
terragruntOptions.Logger.Printf("Enabling root access to S3 bucket %s", config.Bucket)

accountID, err := aws_helper.GetAWSAccountID(terragruntOptions)
if err != nil {
return errors.WithStackTrace(err)
}
Expand Down

0 comments on commit a4e647a

Please sign in to comment.