diff --git a/internal/config/config.go b/internal/config/config.go index 7d035c9..93e9920 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -315,34 +315,44 @@ func NewConfig(attrs *Attributes) (*Config, error) { return cfg, nil } +func getFlagNameFromProfile(awsProfile string, flag string) string { + profileKey := fmt.Sprintf("%s.%s", awsProfile, flag) + if awsProfile != "" && viper.IsSet(profileKey) == true { + return profileKey + } + return flag +} + func readConfig() (Attributes, error) { + awsProfile := viper.GetString(ProfileFlag) + attrs := Attributes{ - AllProfiles: viper.GetBool(AllProfilesFlag), - AuthzID: viper.GetString(AuthzIDFlag), - AWSCredentials: viper.GetString(AWSCredentialsFlag), - AWSIAMIdP: viper.GetString(AWSIAMIdPFlag), - AWSIAMRole: viper.GetString(AWSIAMRoleFlag), - AWSSessionDuration: viper.GetInt64(SessionDurationFlag), - AWSRegion: viper.GetString(AWSRegionFlag), - CustomScope: viper.GetString(CustomScopeFlag), - Debug: viper.GetBool(DebugFlag), - DebugAPICalls: viper.GetBool(DebugAPICallsFlag), - Exec: viper.GetBool(ExecFlag), - FedAppID: viper.GetString(AWSAcctFedAppIDFlag), - Format: viper.GetString(FormatFlag), - LegacyAWSVariables: viper.GetBool(LegacyAWSVariablesFlag), - ExpiryAWSVariables: viper.GetBool(ExpiryAWSVariablesFlag), - CacheAccessToken: viper.GetBool(CacheAccessTokenFlag), - OIDCAppID: viper.GetString(OIDCClientIDFlag), - OpenBrowser: viper.GetBool(OpenBrowserFlag), - OpenBrowserCommand: viper.GetString(OpenBrowserCommandFlag), - OrgDomain: viper.GetString(OrgDomainFlag), - PrivateKey: viper.GetString(PrivateKeyFlag), - PrivateKeyFile: viper.GetString(PrivateKeyFileFlag), - KeyID: viper.GetString(KeyIDFlag), - Profile: viper.GetString(ProfileFlag), - QRCode: viper.GetBool(QRCodeFlag), - WriteAWSCredentials: viper.GetBool(WriteAWSCredentialsFlag), + AllProfiles: viper.GetBool(getFlagNameFromProfile(awsProfile, AllProfilesFlag)), + AuthzID: viper.GetString(getFlagNameFromProfile(awsProfile, AuthzIDFlag)), + AWSCredentials: viper.GetString(getFlagNameFromProfile(awsProfile, AWSCredentialsFlag)), + AWSIAMIdP: viper.GetString(getFlagNameFromProfile(awsProfile, AWSIAMIdPFlag)), + AWSIAMRole: viper.GetString(getFlagNameFromProfile(awsProfile, AWSIAMRoleFlag)), + AWSRegion: viper.GetString(getFlagNameFromProfile(awsProfile, AWSRegionFlag)), + AWSSessionDuration: viper.GetInt64(getFlagNameFromProfile(awsProfile, SessionDurationFlag)), + CustomScope: viper.GetString(getFlagNameFromProfile(awsProfile, CustomScopeFlag)), + Debug: viper.GetBool(getFlagNameFromProfile(awsProfile, DebugFlag)), + DebugAPICalls: viper.GetBool(getFlagNameFromProfile(awsProfile, DebugAPICallsFlag)), + Exec: viper.GetBool(getFlagNameFromProfile(awsProfile, ExecFlag)), + FedAppID: viper.GetString(getFlagNameFromProfile(awsProfile, AWSAcctFedAppIDFlag)), + Format: viper.GetString(getFlagNameFromProfile(awsProfile, FormatFlag)), + LegacyAWSVariables: viper.GetBool(getFlagNameFromProfile(awsProfile, LegacyAWSVariablesFlag)), + ExpiryAWSVariables: viper.GetBool(getFlagNameFromProfile(awsProfile, ExpiryAWSVariablesFlag)), + CacheAccessToken: viper.GetBool(getFlagNameFromProfile(awsProfile, CacheAccessTokenFlag)), + OIDCAppID: viper.GetString(getFlagNameFromProfile(awsProfile, OIDCClientIDFlag)), + OpenBrowser: viper.GetBool(getFlagNameFromProfile(awsProfile, OpenBrowserFlag)), + OpenBrowserCommand: viper.GetString(getFlagNameFromProfile(awsProfile, OpenBrowserCommandFlag)), + OrgDomain: viper.GetString(getFlagNameFromProfile(awsProfile, OrgDomainFlag)), + PrivateKey: viper.GetString(getFlagNameFromProfile(awsProfile, PrivateKeyFlag)), + PrivateKeyFile: viper.GetString(getFlagNameFromProfile(awsProfile, PrivateKeyFileFlag)), + KeyID: viper.GetString(getFlagNameFromProfile(awsProfile, KeyIDFlag)), + Profile: viper.GetString(getFlagNameFromProfile(awsProfile, ProfileFlag)), + QRCode: viper.GetBool(getFlagNameFromProfile(awsProfile, QRCodeFlag)), + WriteAWSCredentials: viper.GetBool(getFlagNameFromProfile(awsProfile, WriteAWSCredentialsFlag)), } if attrs.Format == "" { attrs.Format = EnvVarFormat diff --git a/internal/flag/flag.go b/internal/flag/flag.go index 1317f8f..c6d12d1 100644 --- a/internal/flag/flag.go +++ b/internal/flag/flag.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "os" + "os/user" "path/filepath" "strings" @@ -81,7 +82,30 @@ func MakeFlagBindings(cmd *cobra.Command, flags []Flag, persistent bool) { if vipAwsRegion != "" && os.Getenv(awsRegionEnvVar) == "" { _ = os.Setenv(awsRegionEnvVar, vipAwsRegion) } + } else { + // Check if .okta-aws-cli/conifg.yml exists + usr, err := user.Current() + if err == nil { + oktaConfig := filepath.Join(usr.HomeDir, ".okta-aws-cli", "config.yml") + if _, err := os.Stat(oktaConfig); err == nil || !errors.Is(err, os.ErrNotExist) { + viper.AddConfigPath(filepath.Join(usr.HomeDir, ".okta-aws-cli")) + viper.SetConfigName("config.yml") + viper.SetConfigType("yml") + + _ = viper.ReadInConfig() + + // After viper reads in the dotenv file check if AWS_REGION is set + // there. The value will be keyed by lower case name. If it is, set + // AWS_REGION as an ENV VAR if it hasn't already been. + awsRegionEnvVar := "AWS_REGION" + vipAwsRegion := viper.GetString(strings.ToLower(awsRegionEnvVar)) + if vipAwsRegion != "" && os.Getenv(awsRegionEnvVar) == "" { + _ = os.Setenv(awsRegionEnvVar, vipAwsRegion) + } + } + } } + viper.AutomaticEnv() // bind cli flags