Skip to content

Commit

Permalink
m2m allow STS role session name to be overridden with `--aws-sts-role…
Browse files Browse the repository at this point in the history
…-session-name [value]` CLI flag.

Closes #165
  • Loading branch information
monde committed Jan 11, 2025
1 parent e26279e commit 55b7463
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 138 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ These settings are optional unless marked otherwise:
| Private Key File (**required** in lieu of private key) | File holding PEM (pkcs#1 or pkcs#9) private key whose public key is stored on the service app | `--private-key-file [value]` | `OKTA_AWSCLI_PRIVATE_KEY_FILE` |
| Authorization Server ID | The ID of the Okta authorization server, set ID for a custom authorization server, will use default otherwise. Default `default` | `--authz-id [value]` | `OKTA_AWSCLI_AUTHZ_ID` |
| Custom scope name | The custom scope established in the custom authorization server. Default `okta-m2m-access` | `--custom-scope [value]` | `OKTA_AWSCLI_CUSTOM_SCOPE` |
| Custom STS Role Session Name | Customize STS Role Session Name. Default `okta-aws-cli` | `--aws-sts-role-session-name [value]` | `OKTA_AWSCLI_STS_ROLE_SESSION_NAME` |

### Friendly IdP and Role menu labels

Expand Down
7 changes: 7 additions & 0 deletions cmd/root/m2m/m2m.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ var (
Usage: "Custom Authorization Server ID",
EnvVar: config.AuthzIDEnvVar,
},
{
Name: config.AWSSTSRoleSessionNameFlag,
Short: "q",
Value: "okta-aws-cli",
Usage: "STS Role Session Name",
EnvVar: config.AWSSTSRoleSessionNameEnvVar,
},
}
requiredFlags = []interface{}{"org-domain", "oidc-client-id", "aws-iam-role", "key-id", []string{"private-key", "private-key-file"}}
)
Expand Down
297 changes: 160 additions & 137 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ const (
AWSIAMRoleFlag = "aws-iam-role"
// AWSRegionFlag cli flag const
AWSRegionFlag = "aws-region"
// AWSSTSRoleSessionNameFlag cli flag const
AWSSTSRoleSessionNameFlag = "aws-sts-role-session-name"
// CustomScopeFlag cli flag const
CustomScopeFlag = "custom-scope"
// DebugFlag cli flag const
Expand Down Expand Up @@ -139,6 +141,8 @@ const (
AWSSessionDurationEnvVar = "OKTA_AWSCLI_SESSION_DURATION"
// AWSRegionEnvVar env var const
AWSRegionEnvVar = "OKTA_AWSCLI_AWS_REGION"
// AWSSTSRoleSessionNameEnvVar env var const
AWSSTSRoleSessionNameEnvVar = "OKTA_AWSCLI_STS_ROLE_SESSION_NAME"
// CacheAccessTokenEnvVar env var const
CacheAccessTokenEnvVar = "OKTA_AWSCLI_CACHE_ACCESS_TOKEN"
// CustomScopeEnvVar env var const
Expand Down Expand Up @@ -214,32 +218,33 @@ type OktaYamlConfig struct {
// name. This is a convenience struct pretty printing profile information from
// the list profiles command cmd/root/profileslist/profiles-list.go
type OktaYamlConfigProfile struct {
AllProfiles string `yaml:"all-profiles"`
AuthzID string `yaml:"authz-id"`
AWSAcctFedAppID string `yaml:"aws-acct-fed-app-id"`
AWSCredentials string `yaml:"aws-credentials"`
AWSIAMIdP string `yaml:"aws-iam-idp"`
AWSIAMRole string `yaml:"aws-iam-role"`
AWSRegion string `yaml:"aws-region"`
CustomScope string `yaml:"custom-scope"`
Debug string `yaml:"debug"`
DebugAPICalls string `yaml:"debug-api-calls"`
Exec string `yaml:"exec"`
Format string `yaml:"format"`
OIDCClientID string `yaml:"oidc-client-id"`
OpenBrowser string `yaml:"open-browser"`
OpenBrowserCommand string `yaml:"open-browser-command"`
OrgDomain string `yaml:"org-domain"`
PrivateKey string `yaml:"private-key"`
PrivateKeyFile string `yaml:"private-key-file"`
KeyID string `yaml:"key-id"`
Profile string `yaml:"profile"`
QRCode string `yaml:"qr-code"`
SessionDuration string `yaml:"session-duration"`
WriteAWSCredentials string `yaml:"write-aws-credentials"`
LegacyAWSVariables string `yaml:"legacy-aws-variables"`
ExpiryAWSVariables string `yaml:"expiry-aws-variables"`
CacheAccessToken string `yaml:"cache-access-token"`
AllProfiles string `yaml:"all-profiles"`
AuthzID string `yaml:"authz-id"`
AWSAcctFedAppID string `yaml:"aws-acct-fed-app-id"`
AWSCredentials string `yaml:"aws-credentials"`
AWSIAMIdP string `yaml:"aws-iam-idp"`
AWSIAMRole string `yaml:"aws-iam-role"`
AWSRegion string `yaml:"aws-region"`
AWSSTSRoleSessionName string `yaml:"aws-sts-role-session-name"`
CustomScope string `yaml:"custom-scope"`
Debug string `yaml:"debug"`
DebugAPICalls string `yaml:"debug-api-calls"`
Exec string `yaml:"exec"`
Format string `yaml:"format"`
OIDCClientID string `yaml:"oidc-client-id"`
OpenBrowser string `yaml:"open-browser"`
OpenBrowserCommand string `yaml:"open-browser-command"`
OrgDomain string `yaml:"org-domain"`
PrivateKey string `yaml:"private-key"`
PrivateKeyFile string `yaml:"private-key-file"`
KeyID string `yaml:"key-id"`
Profile string `yaml:"profile"`
QRCode string `yaml:"qr-code"`
SessionDuration string `yaml:"session-duration"`
WriteAWSCredentials string `yaml:"write-aws-credentials"`
LegacyAWSVariables string `yaml:"legacy-aws-variables"`
ExpiryAWSVariables string `yaml:"expiry-aws-variables"`
CacheAccessToken string `yaml:"cache-access-token"`
}

// Clock interface to abstract time operations
Expand All @@ -254,67 +259,69 @@ type Clock interface {
// control data access, be concerned with evaluation, validation, and not
// allowing direct access to values as is done on structs in the generic case.
type Config struct {
allProfiles bool
authzID string
awsCredentials string
awsIAMIdP string
awsIAMRole string
awsRegion string
awsSessionDuration int64
cacheAccessToken bool
customScope string
debug bool
debugAPICalls bool
exec bool
expiryAWSVariables bool
fedAppID string
format string
httpClient *http.Client
keyID string
legacyAWSVariables bool
oidcAppID string
openBrowser bool
openBrowserCommand string
orgDomain string
privateKey string
privateKeyFile string
profile string
qrCode bool
shortUserAgent bool
writeAWSCredentials bool
clock Clock
Logger logger.Logger
allProfiles bool
authzID string
awsCredentials string
awsIAMIdP string
awsIAMRole string
awsRegion string
awsSessionDuration int64
awsSTSRoleSessionName string
cacheAccessToken bool
customScope string
debug bool
debugAPICalls bool
exec bool
expiryAWSVariables bool
fedAppID string
format string
httpClient *http.Client
keyID string
legacyAWSVariables bool
oidcAppID string
openBrowser bool
openBrowserCommand string
orgDomain string
privateKey string
privateKeyFile string
profile string
qrCode bool
shortUserAgent bool
writeAWSCredentials bool
clock Clock
Logger logger.Logger
}

// Attributes attributes for config construction
type Attributes struct {
AllProfiles bool
AuthzID string
AWSCredentials string
AWSIAMIdP string
AWSIAMRole string
AWSRegion string
AWSSessionDuration int64
CacheAccessToken bool
CustomScope string
Debug bool
DebugAPICalls bool
Exec bool
ExpiryAWSVariables bool
FedAppID string
Format string
KeyID string
LegacyAWSVariables bool
OIDCAppID string
OpenBrowser bool
OpenBrowserCommand string
OrgDomain string
PrivateKey string
PrivateKeyFile string
Profile string
QRCode bool
ShortUserAgent bool
WriteAWSCredentials bool
AllProfiles bool
AuthzID string
AWSCredentials string
AWSIAMIdP string
AWSIAMRole string
AWSRegion string
AWSSessionDuration int64
AWSSTSRoleSessionName string
CacheAccessToken bool
CustomScope string
Debug bool
DebugAPICalls bool
Exec bool
ExpiryAWSVariables bool
FedAppID string
Format string
KeyID string
LegacyAWSVariables bool
OIDCAppID string
OpenBrowser bool
OpenBrowserCommand string
OrgDomain string
PrivateKey string
PrivateKeyFile string
Profile string
QRCode bool
ShortUserAgent bool
WriteAWSCredentials bool
}

// NewEvaluatedConfig Returns a new config loading and evaluating attributes in
Expand Down Expand Up @@ -345,33 +352,34 @@ func NewEvaluatedConfig() (*Config, error) {
func NewConfig(attrs *Attributes) (*Config, error) {
var err error
cfg := &Config{
allProfiles: attrs.AllProfiles,
authzID: attrs.AuthzID,
awsCredentials: attrs.AWSCredentials,
awsIAMIdP: attrs.AWSIAMIdP,
awsIAMRole: attrs.AWSIAMRole,
awsRegion: attrs.AWSRegion,
awsSessionDuration: attrs.AWSSessionDuration,
cacheAccessToken: attrs.CacheAccessToken,
customScope: attrs.CustomScope,
debug: attrs.Debug,
debugAPICalls: attrs.DebugAPICalls,
exec: attrs.Exec,
expiryAWSVariables: attrs.ExpiryAWSVariables,
fedAppID: attrs.FedAppID,
format: attrs.Format,
keyID: attrs.KeyID,
legacyAWSVariables: attrs.LegacyAWSVariables,
oidcAppID: attrs.OIDCAppID,
openBrowser: attrs.OpenBrowser,
openBrowserCommand: attrs.OpenBrowserCommand,
orgDomain: attrs.OrgDomain,
privateKey: attrs.PrivateKey,
privateKeyFile: attrs.PrivateKeyFile,
profile: attrs.Profile,
qrCode: attrs.QRCode,
shortUserAgent: attrs.ShortUserAgent,
writeAWSCredentials: attrs.WriteAWSCredentials,
allProfiles: attrs.AllProfiles,
authzID: attrs.AuthzID,
awsCredentials: attrs.AWSCredentials,
awsIAMIdP: attrs.AWSIAMIdP,
awsIAMRole: attrs.AWSIAMRole,
awsRegion: attrs.AWSRegion,
awsSessionDuration: attrs.AWSSessionDuration,
awsSTSRoleSessionName: attrs.AWSSTSRoleSessionName,
cacheAccessToken: attrs.CacheAccessToken,
customScope: attrs.CustomScope,
debug: attrs.Debug,
debugAPICalls: attrs.DebugAPICalls,
exec: attrs.Exec,
expiryAWSVariables: attrs.ExpiryAWSVariables,
fedAppID: attrs.FedAppID,
format: attrs.Format,
keyID: attrs.KeyID,
legacyAWSVariables: attrs.LegacyAWSVariables,
oidcAppID: attrs.OIDCAppID,
openBrowser: attrs.OpenBrowser,
openBrowserCommand: attrs.OpenBrowserCommand,
orgDomain: attrs.OrgDomain,
privateKey: attrs.PrivateKey,
privateKeyFile: attrs.PrivateKeyFile,
profile: attrs.Profile,
qrCode: attrs.QRCode,
shortUserAgent: attrs.ShortUserAgent,
writeAWSCredentials: attrs.WriteAWSCredentials,
}
err = cfg.SetOrgDomain(attrs.OrgDomain)
if err != nil {
Expand Down Expand Up @@ -462,33 +470,34 @@ func loadConfigAttributesFromFlagsAndVars() (Attributes, error) {
}

attrs := Attributes{
AllProfiles: viper.GetBool(getFlagNameFromProfile(awsProfile, AllProfilesFlag)),
AuthzID: viper.GetString(getFlagNameFromProfile(awsProfile, AuthzIDFlag)),
AWSCredentials: viper.GetString(getFlagNameFromProfile(awsProfile, AWSCredentialsFlag)),
AWSIAMIdP: viper.GetString(getFlagNameFromProfile(awsProfile, AWSIAMIdPFlag)),
AWSIAMRole: viper.GetString(getFlagNameFromProfile(awsProfile, AWSIAMRoleFlag)),
AWSRegion: viper.GetString(getFlagNameFromProfile(awsProfile, AWSRegionFlag)),
AWSSessionDuration: viper.GetInt64(getFlagNameFromProfile(awsProfile, SessionDurationFlag)),
CustomScope: viper.GetString(getFlagNameFromProfile(awsProfile, CustomScopeFlag)),
Debug: viper.GetBool(getFlagNameFromProfile(awsProfile, DebugFlag)),
DebugAPICalls: viper.GetBool(getFlagNameFromProfile(awsProfile, DebugAPICallsFlag)),
Exec: viper.GetBool(getFlagNameFromProfile(awsProfile, ExecFlag)),
FedAppID: viper.GetString(getFlagNameFromProfile(awsProfile, AWSAcctFedAppIDFlag)),
Format: viper.GetString(getFlagNameFromProfile(awsProfile, FormatFlag)),
LegacyAWSVariables: viper.GetBool(getFlagNameFromProfile(awsProfile, LegacyAWSVariablesFlag)),
ExpiryAWSVariables: viper.GetBool(getFlagNameFromProfile(awsProfile, ExpiryAWSVariablesFlag)),
CacheAccessToken: viper.GetBool(getFlagNameFromProfile(awsProfile, CacheAccessTokenFlag)),
OIDCAppID: viper.GetString(getFlagNameFromProfile(awsProfile, OIDCClientIDFlag)),
OpenBrowser: viper.GetBool(getFlagNameFromProfile(awsProfile, OpenBrowserFlag)),
OpenBrowserCommand: viper.GetString(getFlagNameFromProfile(awsProfile, OpenBrowserCommandFlag)),
OrgDomain: viper.GetString(getFlagNameFromProfile(awsProfile, OrgDomainFlag)),
PrivateKey: viper.GetString(getFlagNameFromProfile(awsProfile, PrivateKeyFlag)),
PrivateKeyFile: viper.GetString(getFlagNameFromProfile(awsProfile, PrivateKeyFileFlag)),
KeyID: viper.GetString(getFlagNameFromProfile(awsProfile, KeyIDFlag)),
Profile: awsProfile,
QRCode: viper.GetBool(getFlagNameFromProfile(awsProfile, QRCodeFlag)),
ShortUserAgent: viper.GetBool(getFlagNameFromProfile(awsProfile, ShortUserAgentFlag)),
WriteAWSCredentials: viper.GetBool(getFlagNameFromProfile(awsProfile, WriteAWSCredentialsFlag)),
AllProfiles: viper.GetBool(getFlagNameFromProfile(awsProfile, AllProfilesFlag)),
AuthzID: viper.GetString(getFlagNameFromProfile(awsProfile, AuthzIDFlag)),
AWSCredentials: viper.GetString(getFlagNameFromProfile(awsProfile, AWSCredentialsFlag)),
AWSIAMIdP: viper.GetString(getFlagNameFromProfile(awsProfile, AWSIAMIdPFlag)),
AWSIAMRole: viper.GetString(getFlagNameFromProfile(awsProfile, AWSIAMRoleFlag)),
AWSRegion: viper.GetString(getFlagNameFromProfile(awsProfile, AWSRegionFlag)),
AWSSessionDuration: viper.GetInt64(getFlagNameFromProfile(awsProfile, SessionDurationFlag)),
AWSSTSRoleSessionName: viper.GetString(getFlagNameFromProfile(awsProfile, AWSSTSRoleSessionNameFlag)),
CustomScope: viper.GetString(getFlagNameFromProfile(awsProfile, CustomScopeFlag)),
Debug: viper.GetBool(getFlagNameFromProfile(awsProfile, DebugFlag)),
DebugAPICalls: viper.GetBool(getFlagNameFromProfile(awsProfile, DebugAPICallsFlag)),
Exec: viper.GetBool(getFlagNameFromProfile(awsProfile, ExecFlag)),
FedAppID: viper.GetString(getFlagNameFromProfile(awsProfile, AWSAcctFedAppIDFlag)),
Format: viper.GetString(getFlagNameFromProfile(awsProfile, FormatFlag)),
LegacyAWSVariables: viper.GetBool(getFlagNameFromProfile(awsProfile, LegacyAWSVariablesFlag)),
ExpiryAWSVariables: viper.GetBool(getFlagNameFromProfile(awsProfile, ExpiryAWSVariablesFlag)),
CacheAccessToken: viper.GetBool(getFlagNameFromProfile(awsProfile, CacheAccessTokenFlag)),
OIDCAppID: viper.GetString(getFlagNameFromProfile(awsProfile, OIDCClientIDFlag)),
OpenBrowser: viper.GetBool(getFlagNameFromProfile(awsProfile, OpenBrowserFlag)),
OpenBrowserCommand: viper.GetString(getFlagNameFromProfile(awsProfile, OpenBrowserCommandFlag)),
OrgDomain: viper.GetString(getFlagNameFromProfile(awsProfile, OrgDomainFlag)),
PrivateKey: viper.GetString(getFlagNameFromProfile(awsProfile, PrivateKeyFlag)),
PrivateKeyFile: viper.GetString(getFlagNameFromProfile(awsProfile, PrivateKeyFileFlag)),
KeyID: viper.GetString(getFlagNameFromProfile(awsProfile, KeyIDFlag)),
Profile: awsProfile,
QRCode: viper.GetBool(getFlagNameFromProfile(awsProfile, QRCodeFlag)),
ShortUserAgent: viper.GetBool(getFlagNameFromProfile(awsProfile, ShortUserAgentFlag)),
WriteAWSCredentials: viper.GetBool(getFlagNameFromProfile(awsProfile, WriteAWSCredentialsFlag)),
}
if attrs.Format == "" {
attrs.Format = EnvVarFormat
Expand Down Expand Up @@ -521,6 +530,9 @@ func loadConfigAttributesFromFlagsAndVars() (Attributes, error) {
if attrs.AWSIAMRole == "" {
attrs.AWSIAMRole = viper.GetString(downCase(AWSIAMRoleEnvVar))
}
if attrs.AWSSTSRoleSessionName == "" {
attrs.AWSSTSRoleSessionName = viper.GetString(downCase(AWSSTSRoleSessionNameEnvVar))
}
if !attrs.QRCode {
attrs.QRCode = viper.GetBool(downCase(QRCodeEnvVar))
}
Expand Down Expand Up @@ -722,6 +734,17 @@ func (c *Config) SetAWSSessionDuration(duration int64) error {
return nil
}

// AWSSTSRoleSessionName --
func (c *Config) AWSSTSRoleSessionName() string {
return c.awsSTSRoleSessionName
}

// SetAWSSTSRoleSessionName --
func (c *Config) SetAWSSTSRoleSessionName(name string) error {
c.awsSTSRoleSessionName = name
return nil
}

// CacheAccessToken --
func (c *Config) CacheAccessToken() bool {
return c.cacheAccessToken
Expand Down
2 changes: 1 addition & 1 deletion internal/m2mauth/m2mauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (m *M2MAuthentication) awsAssumeRoleWithWebIdentity(at *okta.AccessToken) (
input := &sts.AssumeRoleWithWebIdentityInput{
DurationSeconds: aws.Int64(m.config.AWSSessionDuration()),
RoleArn: aws.String(m.config.AWSIAMRole()),
RoleSessionName: aws.String("okta-aws-cli"),
RoleSessionName: aws.String(m.config.AWSSTSRoleSessionName()),
WebIdentityToken: &at.AccessToken,
}
svcResp, err := svc.AssumeRoleWithWebIdentity(input)
Expand Down

0 comments on commit 55b7463

Please sign in to comment.