diff --git a/README.md b/README.md index dc04833..491f434 100644 --- a/README.md +++ b/README.md @@ -470,6 +470,7 @@ These settings are optional unless marked otherwise: | Private Key File (**required** in lieu of private key) | File holding PEM (pkcs#1 or pkcs#9) private key whose public key is stored on the service app | `--private-key-file [value]` | `OKTA_AWSCLI_PRIVATE_KEY_FILE` | | Authorization Server ID | The ID of the Okta authorization server, set ID for a custom authorization server, will use default otherwise. Default `default` | `--authz-id [value]` | `OKTA_AWSCLI_AUTHZ_ID` | | Custom scope name | The custom scope established in the custom authorization server. Default `okta-m2m-access` | `--custom-scope [value]` | `OKTA_AWSCLI_CUSTOM_SCOPE` | +| Custom STS Role Session Name | Customize STS Role Session Name. Default `okta-aws-cli` | `--aws-sts-role-session-name [value]` | `OKTA_AWSCLI_STS_ROLE_SESSION_NAME` | ### Friendly IdP and Role menu labels diff --git a/cmd/root/m2m/m2m.go b/cmd/root/m2m/m2m.go index cb99066..2281db0 100644 --- a/cmd/root/m2m/m2m.go +++ b/cmd/root/m2m/m2m.go @@ -61,6 +61,13 @@ var ( Usage: "Custom Authorization Server ID", EnvVar: config.AuthzIDEnvVar, }, + { + Name: config.AWSSTSRoleSessionNameFlag, + Short: "q", + Value: "okta-aws-cli", + Usage: "STS Role Session Name", + EnvVar: config.AWSSTSRoleSessionNameEnvVar, + }, } requiredFlags = []interface{}{"org-domain", "oidc-client-id", "aws-iam-role", "key-id", []string{"private-key", "private-key-file"}} ) diff --git a/internal/config/config.go b/internal/config/config.go index 8f19fff..ffd9604 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -80,6 +80,8 @@ const ( AWSIAMRoleFlag = "aws-iam-role" // AWSRegionFlag cli flag const AWSRegionFlag = "aws-region" + // AWSSTSRoleSessionNameFlag cli flag const + AWSSTSRoleSessionNameFlag = "aws-sts-role-session-name" // CustomScopeFlag cli flag const CustomScopeFlag = "custom-scope" // DebugFlag cli flag const @@ -139,6 +141,8 @@ const ( AWSSessionDurationEnvVar = "OKTA_AWSCLI_SESSION_DURATION" // AWSRegionEnvVar env var const AWSRegionEnvVar = "OKTA_AWSCLI_AWS_REGION" + // AWSSTSRoleSessionNameEnvVar env var const + AWSSTSRoleSessionNameEnvVar = "OKTA_AWSCLI_STS_ROLE_SESSION_NAME" // CacheAccessTokenEnvVar env var const CacheAccessTokenEnvVar = "OKTA_AWSCLI_CACHE_ACCESS_TOKEN" // CustomScopeEnvVar env var const @@ -214,32 +218,33 @@ type OktaYamlConfig struct { // name. This is a convenience struct pretty printing profile information from // the list profiles command cmd/root/profileslist/profiles-list.go type OktaYamlConfigProfile struct { - AllProfiles string `yaml:"all-profiles"` - AuthzID string `yaml:"authz-id"` - AWSAcctFedAppID string `yaml:"aws-acct-fed-app-id"` - AWSCredentials string `yaml:"aws-credentials"` - AWSIAMIdP string `yaml:"aws-iam-idp"` - AWSIAMRole string `yaml:"aws-iam-role"` - AWSRegion string `yaml:"aws-region"` - CustomScope string `yaml:"custom-scope"` - Debug string `yaml:"debug"` - DebugAPICalls string `yaml:"debug-api-calls"` - Exec string `yaml:"exec"` - Format string `yaml:"format"` - OIDCClientID string `yaml:"oidc-client-id"` - OpenBrowser string `yaml:"open-browser"` - OpenBrowserCommand string `yaml:"open-browser-command"` - OrgDomain string `yaml:"org-domain"` - PrivateKey string `yaml:"private-key"` - PrivateKeyFile string `yaml:"private-key-file"` - KeyID string `yaml:"key-id"` - Profile string `yaml:"profile"` - QRCode string `yaml:"qr-code"` - SessionDuration string `yaml:"session-duration"` - WriteAWSCredentials string `yaml:"write-aws-credentials"` - LegacyAWSVariables string `yaml:"legacy-aws-variables"` - ExpiryAWSVariables string `yaml:"expiry-aws-variables"` - CacheAccessToken string `yaml:"cache-access-token"` + AllProfiles string `yaml:"all-profiles"` + AuthzID string `yaml:"authz-id"` + AWSAcctFedAppID string `yaml:"aws-acct-fed-app-id"` + AWSCredentials string `yaml:"aws-credentials"` + AWSIAMIdP string `yaml:"aws-iam-idp"` + AWSIAMRole string `yaml:"aws-iam-role"` + AWSRegion string `yaml:"aws-region"` + AWSSTSRoleSessionName string `yaml:"aws-sts-role-session-name"` + CustomScope string `yaml:"custom-scope"` + Debug string `yaml:"debug"` + DebugAPICalls string `yaml:"debug-api-calls"` + Exec string `yaml:"exec"` + Format string `yaml:"format"` + OIDCClientID string `yaml:"oidc-client-id"` + OpenBrowser string `yaml:"open-browser"` + OpenBrowserCommand string `yaml:"open-browser-command"` + OrgDomain string `yaml:"org-domain"` + PrivateKey string `yaml:"private-key"` + PrivateKeyFile string `yaml:"private-key-file"` + KeyID string `yaml:"key-id"` + Profile string `yaml:"profile"` + QRCode string `yaml:"qr-code"` + SessionDuration string `yaml:"session-duration"` + WriteAWSCredentials string `yaml:"write-aws-credentials"` + LegacyAWSVariables string `yaml:"legacy-aws-variables"` + ExpiryAWSVariables string `yaml:"expiry-aws-variables"` + CacheAccessToken string `yaml:"cache-access-token"` } // Clock interface to abstract time operations @@ -254,67 +259,69 @@ type Clock interface { // control data access, be concerned with evaluation, validation, and not // allowing direct access to values as is done on structs in the generic case. type Config struct { - allProfiles bool - authzID string - awsCredentials string - awsIAMIdP string - awsIAMRole string - awsRegion string - awsSessionDuration int64 - cacheAccessToken bool - customScope string - debug bool - debugAPICalls bool - exec bool - expiryAWSVariables bool - fedAppID string - format string - httpClient *http.Client - keyID string - legacyAWSVariables bool - oidcAppID string - openBrowser bool - openBrowserCommand string - orgDomain string - privateKey string - privateKeyFile string - profile string - qrCode bool - shortUserAgent bool - writeAWSCredentials bool - clock Clock - Logger logger.Logger + allProfiles bool + authzID string + awsCredentials string + awsIAMIdP string + awsIAMRole string + awsRegion string + awsSessionDuration int64 + awsSTSRoleSessionName string + cacheAccessToken bool + customScope string + debug bool + debugAPICalls bool + exec bool + expiryAWSVariables bool + fedAppID string + format string + httpClient *http.Client + keyID string + legacyAWSVariables bool + oidcAppID string + openBrowser bool + openBrowserCommand string + orgDomain string + privateKey string + privateKeyFile string + profile string + qrCode bool + shortUserAgent bool + writeAWSCredentials bool + clock Clock + Logger logger.Logger } // Attributes attributes for config construction type Attributes struct { - AllProfiles bool - AuthzID string - AWSCredentials string - AWSIAMIdP string - AWSIAMRole string - AWSRegion string - AWSSessionDuration int64 - CacheAccessToken bool - CustomScope string - Debug bool - DebugAPICalls bool - Exec bool - ExpiryAWSVariables bool - FedAppID string - Format string - KeyID string - LegacyAWSVariables bool - OIDCAppID string - OpenBrowser bool - OpenBrowserCommand string - OrgDomain string - PrivateKey string - PrivateKeyFile string - Profile string - QRCode bool - ShortUserAgent bool - WriteAWSCredentials bool + AllProfiles bool + AuthzID string + AWSCredentials string + AWSIAMIdP string + AWSIAMRole string + AWSRegion string + AWSSessionDuration int64 + AWSSTSRoleSessionName string + CacheAccessToken bool + CustomScope string + Debug bool + DebugAPICalls bool + Exec bool + ExpiryAWSVariables bool + FedAppID string + Format string + KeyID string + LegacyAWSVariables bool + OIDCAppID string + OpenBrowser bool + OpenBrowserCommand string + OrgDomain string + PrivateKey string + PrivateKeyFile string + Profile string + QRCode bool + ShortUserAgent bool + WriteAWSCredentials bool } // NewEvaluatedConfig Returns a new config loading and evaluating attributes in @@ -345,33 +352,34 @@ func NewEvaluatedConfig() (*Config, error) { func NewConfig(attrs *Attributes) (*Config, error) { var err error cfg := &Config{ - allProfiles: attrs.AllProfiles, - authzID: attrs.AuthzID, - awsCredentials: attrs.AWSCredentials, - awsIAMIdP: attrs.AWSIAMIdP, - awsIAMRole: attrs.AWSIAMRole, - awsRegion: attrs.AWSRegion, - awsSessionDuration: attrs.AWSSessionDuration, - cacheAccessToken: attrs.CacheAccessToken, - customScope: attrs.CustomScope, - debug: attrs.Debug, - debugAPICalls: attrs.DebugAPICalls, - exec: attrs.Exec, - expiryAWSVariables: attrs.ExpiryAWSVariables, - fedAppID: attrs.FedAppID, - format: attrs.Format, - keyID: attrs.KeyID, - legacyAWSVariables: attrs.LegacyAWSVariables, - oidcAppID: attrs.OIDCAppID, - openBrowser: attrs.OpenBrowser, - openBrowserCommand: attrs.OpenBrowserCommand, - orgDomain: attrs.OrgDomain, - privateKey: attrs.PrivateKey, - privateKeyFile: attrs.PrivateKeyFile, - profile: attrs.Profile, - qrCode: attrs.QRCode, - shortUserAgent: attrs.ShortUserAgent, - writeAWSCredentials: attrs.WriteAWSCredentials, + allProfiles: attrs.AllProfiles, + authzID: attrs.AuthzID, + awsCredentials: attrs.AWSCredentials, + awsIAMIdP: attrs.AWSIAMIdP, + awsIAMRole: attrs.AWSIAMRole, + awsRegion: attrs.AWSRegion, + awsSessionDuration: attrs.AWSSessionDuration, + awsSTSRoleSessionName: attrs.AWSSTSRoleSessionName, + cacheAccessToken: attrs.CacheAccessToken, + customScope: attrs.CustomScope, + debug: attrs.Debug, + debugAPICalls: attrs.DebugAPICalls, + exec: attrs.Exec, + expiryAWSVariables: attrs.ExpiryAWSVariables, + fedAppID: attrs.FedAppID, + format: attrs.Format, + keyID: attrs.KeyID, + legacyAWSVariables: attrs.LegacyAWSVariables, + oidcAppID: attrs.OIDCAppID, + openBrowser: attrs.OpenBrowser, + openBrowserCommand: attrs.OpenBrowserCommand, + orgDomain: attrs.OrgDomain, + privateKey: attrs.PrivateKey, + privateKeyFile: attrs.PrivateKeyFile, + profile: attrs.Profile, + qrCode: attrs.QRCode, + shortUserAgent: attrs.ShortUserAgent, + writeAWSCredentials: attrs.WriteAWSCredentials, } err = cfg.SetOrgDomain(attrs.OrgDomain) if err != nil { @@ -462,33 +470,34 @@ func loadConfigAttributesFromFlagsAndVars() (Attributes, error) { } attrs := Attributes{ - AllProfiles: viper.GetBool(getFlagNameFromProfile(awsProfile, AllProfilesFlag)), - AuthzID: viper.GetString(getFlagNameFromProfile(awsProfile, AuthzIDFlag)), - AWSCredentials: viper.GetString(getFlagNameFromProfile(awsProfile, AWSCredentialsFlag)), - AWSIAMIdP: viper.GetString(getFlagNameFromProfile(awsProfile, AWSIAMIdPFlag)), - AWSIAMRole: viper.GetString(getFlagNameFromProfile(awsProfile, AWSIAMRoleFlag)), - AWSRegion: viper.GetString(getFlagNameFromProfile(awsProfile, AWSRegionFlag)), - AWSSessionDuration: viper.GetInt64(getFlagNameFromProfile(awsProfile, SessionDurationFlag)), - CustomScope: viper.GetString(getFlagNameFromProfile(awsProfile, CustomScopeFlag)), - Debug: viper.GetBool(getFlagNameFromProfile(awsProfile, DebugFlag)), - DebugAPICalls: viper.GetBool(getFlagNameFromProfile(awsProfile, DebugAPICallsFlag)), - Exec: viper.GetBool(getFlagNameFromProfile(awsProfile, ExecFlag)), - FedAppID: viper.GetString(getFlagNameFromProfile(awsProfile, AWSAcctFedAppIDFlag)), - Format: viper.GetString(getFlagNameFromProfile(awsProfile, FormatFlag)), - LegacyAWSVariables: viper.GetBool(getFlagNameFromProfile(awsProfile, LegacyAWSVariablesFlag)), - ExpiryAWSVariables: viper.GetBool(getFlagNameFromProfile(awsProfile, ExpiryAWSVariablesFlag)), - CacheAccessToken: viper.GetBool(getFlagNameFromProfile(awsProfile, CacheAccessTokenFlag)), - OIDCAppID: viper.GetString(getFlagNameFromProfile(awsProfile, OIDCClientIDFlag)), - OpenBrowser: viper.GetBool(getFlagNameFromProfile(awsProfile, OpenBrowserFlag)), - OpenBrowserCommand: viper.GetString(getFlagNameFromProfile(awsProfile, OpenBrowserCommandFlag)), - OrgDomain: viper.GetString(getFlagNameFromProfile(awsProfile, OrgDomainFlag)), - PrivateKey: viper.GetString(getFlagNameFromProfile(awsProfile, PrivateKeyFlag)), - PrivateKeyFile: viper.GetString(getFlagNameFromProfile(awsProfile, PrivateKeyFileFlag)), - KeyID: viper.GetString(getFlagNameFromProfile(awsProfile, KeyIDFlag)), - Profile: awsProfile, - QRCode: viper.GetBool(getFlagNameFromProfile(awsProfile, QRCodeFlag)), - ShortUserAgent: viper.GetBool(getFlagNameFromProfile(awsProfile, ShortUserAgentFlag)), - WriteAWSCredentials: viper.GetBool(getFlagNameFromProfile(awsProfile, WriteAWSCredentialsFlag)), + AllProfiles: viper.GetBool(getFlagNameFromProfile(awsProfile, AllProfilesFlag)), + AuthzID: viper.GetString(getFlagNameFromProfile(awsProfile, AuthzIDFlag)), + AWSCredentials: viper.GetString(getFlagNameFromProfile(awsProfile, AWSCredentialsFlag)), + AWSIAMIdP: viper.GetString(getFlagNameFromProfile(awsProfile, AWSIAMIdPFlag)), + AWSIAMRole: viper.GetString(getFlagNameFromProfile(awsProfile, AWSIAMRoleFlag)), + AWSRegion: viper.GetString(getFlagNameFromProfile(awsProfile, AWSRegionFlag)), + AWSSessionDuration: viper.GetInt64(getFlagNameFromProfile(awsProfile, SessionDurationFlag)), + AWSSTSRoleSessionName: viper.GetString(getFlagNameFromProfile(awsProfile, AWSSTSRoleSessionNameFlag)), + CustomScope: viper.GetString(getFlagNameFromProfile(awsProfile, CustomScopeFlag)), + Debug: viper.GetBool(getFlagNameFromProfile(awsProfile, DebugFlag)), + DebugAPICalls: viper.GetBool(getFlagNameFromProfile(awsProfile, DebugAPICallsFlag)), + Exec: viper.GetBool(getFlagNameFromProfile(awsProfile, ExecFlag)), + FedAppID: viper.GetString(getFlagNameFromProfile(awsProfile, AWSAcctFedAppIDFlag)), + Format: viper.GetString(getFlagNameFromProfile(awsProfile, FormatFlag)), + LegacyAWSVariables: viper.GetBool(getFlagNameFromProfile(awsProfile, LegacyAWSVariablesFlag)), + ExpiryAWSVariables: viper.GetBool(getFlagNameFromProfile(awsProfile, ExpiryAWSVariablesFlag)), + CacheAccessToken: viper.GetBool(getFlagNameFromProfile(awsProfile, CacheAccessTokenFlag)), + OIDCAppID: viper.GetString(getFlagNameFromProfile(awsProfile, OIDCClientIDFlag)), + OpenBrowser: viper.GetBool(getFlagNameFromProfile(awsProfile, OpenBrowserFlag)), + OpenBrowserCommand: viper.GetString(getFlagNameFromProfile(awsProfile, OpenBrowserCommandFlag)), + OrgDomain: viper.GetString(getFlagNameFromProfile(awsProfile, OrgDomainFlag)), + PrivateKey: viper.GetString(getFlagNameFromProfile(awsProfile, PrivateKeyFlag)), + PrivateKeyFile: viper.GetString(getFlagNameFromProfile(awsProfile, PrivateKeyFileFlag)), + KeyID: viper.GetString(getFlagNameFromProfile(awsProfile, KeyIDFlag)), + Profile: awsProfile, + QRCode: viper.GetBool(getFlagNameFromProfile(awsProfile, QRCodeFlag)), + ShortUserAgent: viper.GetBool(getFlagNameFromProfile(awsProfile, ShortUserAgentFlag)), + WriteAWSCredentials: viper.GetBool(getFlagNameFromProfile(awsProfile, WriteAWSCredentialsFlag)), } if attrs.Format == "" { attrs.Format = EnvVarFormat @@ -521,6 +530,9 @@ func loadConfigAttributesFromFlagsAndVars() (Attributes, error) { if attrs.AWSIAMRole == "" { attrs.AWSIAMRole = viper.GetString(downCase(AWSIAMRoleEnvVar)) } + if attrs.AWSSTSRoleSessionName == "" { + attrs.AWSSTSRoleSessionName = viper.GetString(downCase(AWSSTSRoleSessionNameEnvVar)) + } if !attrs.QRCode { attrs.QRCode = viper.GetBool(downCase(QRCodeEnvVar)) } @@ -722,6 +734,17 @@ func (c *Config) SetAWSSessionDuration(duration int64) error { return nil } +// AWSSTSRoleSessionName -- +func (c *Config) AWSSTSRoleSessionName() string { + return c.awsSTSRoleSessionName +} + +// SetAWSSTSRoleSessionName -- +func (c *Config) SetAWSSTSRoleSessionName(name string) error { + c.awsSTSRoleSessionName = name + return nil +} + // CacheAccessToken -- func (c *Config) CacheAccessToken() bool { return c.cacheAccessToken diff --git a/internal/m2mauth/m2mauth.go b/internal/m2mauth/m2mauth.go index 5083333..740d1f2 100644 --- a/internal/m2mauth/m2mauth.go +++ b/internal/m2mauth/m2mauth.go @@ -127,7 +127,7 @@ func (m *M2MAuthentication) awsAssumeRoleWithWebIdentity(at *okta.AccessToken) ( input := &sts.AssumeRoleWithWebIdentityInput{ DurationSeconds: aws.Int64(m.config.AWSSessionDuration()), RoleArn: aws.String(m.config.AWSIAMRole()), - RoleSessionName: aws.String("okta-aws-cli"), + RoleSessionName: aws.String(m.config.AWSSTSRoleSessionName()), WebIdentityToken: &at.AccessToken, } svcResp, err := svc.AssumeRoleWithWebIdentity(input)