diff --git a/cmd/login.go b/cmd/login.go index 1f52017..3a494bb 100644 --- a/cmd/login.go +++ b/cmd/login.go @@ -106,13 +106,20 @@ var loginCmd = &cobra.Command{ fmt.Println(err) os.Exit(1) } - intAWS.SetCredentials( + err = intAWS.SetCredentials( + &intAWS.CredentialFileGetter{}, + &intAWS.CredentialFileWriter{}, *result.Credentials.AccessKeyId, *result.Credentials.SecretAccessKey, *result.Credentials.SessionToken, config.DefaultRegion, profileName, ) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + fmt.Printf("Successfully set credentials for: %s\n", profileName) }, } diff --git a/internal/aws/credential-manager.go b/internal/aws/credential-manager.go index 67d8a4d..665a1e1 100644 --- a/internal/aws/credential-manager.go +++ b/internal/aws/credential-manager.go @@ -1,47 +1,151 @@ package aws import ( + "bufio" + "errors" "fmt" - "log" - "os/exec" + "io/ioutil" + "os" + "strings" ) -func SetCredentials(accessKey, secretAccessKey, sessionToken, region, profileName string) { - err := setProfileParameter("aws_access_key_id", accessKey, profileName) +type CredentialFileGetterIface interface { + Get() (string, error) + Exists() (bool, error) + SetHomeDir(string) +} +type CredentialFileWriterIface interface { + Write(b []byte) error + SetHomeDir(string) +} + +type CredentialFileGetter struct { + Homedir string +} + +func (c *CredentialFileGetter) SetHomeDir(homedir string) { + c.Homedir = homedir +} +func (c *CredentialFileGetter) Get() (string, error) { + credentials, err := ioutil.ReadFile(c.Homedir + "/.aws/credentials") if err != nil { - log.Fatalln("Error while writing credentials to file. ", err) + return "", fmt.Errorf("Could not read file: %s", err) } - err = setProfileParameter("aws_secret_access_key", secretAccessKey, profileName) - if err != nil { - log.Fatalln("Error while writing credentials to file. ", err) + + return string(credentials), nil +} +func (c *CredentialFileGetter) Exists() (bool, error) { + credentialsFileExists := false + + if _, err := os.Stat(c.Homedir + "/.aws/credentials"); err == nil { + credentialsFileExists = true + } else if errors.Is(err, os.ErrNotExist) { + credentialsFileExists = false + } else { + return credentialsFileExists, fmt.Errorf("Could not determine if credential file exists: %s", err) } - err = setProfileParameter("aws_session_token", sessionToken, profileName) + return credentialsFileExists, nil +} + +type CredentialFileWriter struct { + Homedir string +} + +func (c *CredentialFileWriter) Write(b []byte) error { + return os.WriteFile(c.Homedir+"/.aws/credentials", b, 0600) +} +func (c *CredentialFileWriter) SetHomeDir(homedir string) { + c.Homedir = homedir +} + +func SetCredentials(credFileGetter CredentialFileGetterIface, credFileWriter CredentialFileWriterIface, accessKey, secretAccessKey, sessionToken, region, profileName string) error { + // set user homedir + homedir, err := os.UserHomeDir() if err != nil { - log.Fatalln("Error while writing credentials to file. ", err) + return fmt.Errorf("could not get user's home directory: %s", err) } - err = setProfileParameter("region", region, profileName) + credFileGetter.SetHomeDir(homedir) + credFileWriter.SetHomeDir(homedir) + + // check whether file exists + exists, err := credFileGetter.Exists() if err != nil { - log.Fatalln("Error while writing credentials to file. ", err) + return fmt.Errorf("Couldn't determine whether credential file exists: %s", err) + } + + // if exists, update / write file + if exists { + credentials, err := credFileGetter.Get() + if err != nil { + return fmt.Errorf("Couldn't get credential file: %s", err) + } + if profileExists(credentials, profileName) { + err = updateCredential(credFileWriter, credentials, accessKey, secretAccessKey, sessionToken, region, profileName) + if err != nil { + return fmt.Errorf("updateCredential error: %s", err) + } + } else { + err = appendCredential(credFileWriter, credentials, accessKey, secretAccessKey, sessionToken, region, profileName) + if err != nil { + return fmt.Errorf("appendCredential error: %s", err) + } + } + } else { // write new file if credential file doesn't exist + err = writeCredentialAsNewFile(credFileWriter, accessKey, secretAccessKey, sessionToken, region, profileName) + if err != nil { + return fmt.Errorf("writeCredentialAsNewFile error: %s", err) + } } + + return nil } -func setProfileParameter(parameter, value, profileName string) error { - command := []string{ - "aws", - "configure", - "set", - parameter, - value, - "--profile", - profileName, +func profileExists(credentials string, profileName string) bool { + scanner := bufio.NewScanner(strings.NewReader(credentials)) + for scanner.Scan() { + if strings.Trim(scanner.Text(), " ") == "["+profileName+"]" { + return true + } } + return false +} - cmd := exec.Command(command[0], command[1], command[2], command[3], command[4], command[5], command[6]) - _, err := cmd.Output() +func updateCredential(writer CredentialFileWriterIface, credentials, accessKey, secretAccessKey, sessionToken, region, profileName string) error { + newCredentials := "" + found := false + scanner := bufio.NewScanner(strings.NewReader(credentials)) + for scanner.Scan() { + line := scanner.Text() + if strings.Trim(line, " ") == "["+profileName+"]" { + found = true + newCredentials += formatCredential(accessKey, secretAccessKey, sessionToken, region, profileName) + } - if err != nil { - fmt.Println(err.Error()) - return err + if !found { + newCredentials += line + "\n" + } + + if found && line == "" { + found = false + newCredentials += "\n\n" + } } - return nil + newCredentials = strings.TrimRight(newCredentials, "\n") + return writer.Write([]byte(newCredentials)) +} +func appendCredential(writer CredentialFileWriterIface, credentials, accessKey, secretAccessKey, sessionToken, region, profileName string) error { + credentials = strings.TrimRight(credentials, "\n") + if credentials != "" { + credentials += "\n\n" + } + credentials += formatCredential(accessKey, secretAccessKey, sessionToken, region, profileName) + return writer.Write([]byte(credentials)) +} + +func writeCredentialAsNewFile(writer CredentialFileWriterIface, accessKey, secretAccessKey, sessionToken, region, profileName string) error { + return writer.Write([]byte(formatCredential(accessKey, secretAccessKey, sessionToken, region, profileName))) +} + +func formatCredential(accessKey, secretAccessKey, sessionToken, region, profileName string) string { + return fmt.Sprintf("[%s]\naws_access_key_id = %s\naws_secret_access_key = %s\naws_session_token = %s", profileName, accessKey, secretAccessKey, sessionToken) } diff --git a/internal/aws/credential-manager_test.go b/internal/aws/credential-manager_test.go new file mode 100644 index 0000000..c54cd98 --- /dev/null +++ b/internal/aws/credential-manager_test.go @@ -0,0 +1,172 @@ +package aws + +import ( + "testing" +) + +type MockCredentialFileGetter struct { + CredentialOutput string + CredentialExists bool +} +type MockCredentialFileWriter struct { + Out []byte +} + +func (c MockCredentialFileGetter) SetHomeDir(homedir string) { + +} +func (c MockCredentialFileGetter) Get() (string, error) { + return c.CredentialOutput, nil +} +func (c MockCredentialFileGetter) Exists() (bool, error) { + return c.CredentialExists, nil +} + +func (c MockCredentialFileWriter) SetHomeDir(homedir string) { + +} +func (c *MockCredentialFileWriter) Write(b []byte) error { + c.Out = b + return nil +} + +func TestSetCredentialsEmpty(t *testing.T) { + mockGetter := MockCredentialFileGetter{ + CredentialOutput: "", + CredentialExists: true, + } + mockWriter := MockCredentialFileWriter{} + err := SetCredentials( + mockGetter, + &mockWriter, + "accessKey", + "secretKey", + "sessionToken", + "us-eas-1", + "myProfile", + ) + if err != nil { + t.Errorf("error: %s", err) + } + if string(mockWriter.Out) != formatCredential("accessKey", "secretKey", "sessionToken", "us-eas-1", "myProfile") { + t.Errorf("Wrong output. Got: %s", mockWriter.Out) + } +} + +func TestSetCredentialsNonExisting(t *testing.T) { + mockGetter := MockCredentialFileGetter{ + CredentialOutput: "", + CredentialExists: false, + } + mockWriter := MockCredentialFileWriter{} + err := SetCredentials( + mockGetter, + &mockWriter, + "accessKey", + "secretKey", + "sessionToken", + "us-eas-1", + "myProfile", + ) + if err != nil { + t.Errorf("error: %s", err) + } + if string(mockWriter.Out) != formatCredential("accessKey", "secretKey", "sessionToken", "us-eas-1", "myProfile") { + t.Errorf("Wrong output. Got: %s", mockWriter.Out) + } +} +func TestSetCredentialsExistingFile(t *testing.T) { + mockGetter := MockCredentialFileGetter{ + CredentialOutput: formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123"), + CredentialExists: true, + } + mockWriter := MockCredentialFileWriter{} + err := SetCredentials( + mockGetter, + &mockWriter, + "accessKey", + "secretKey", + "sessionToken", + "us-east-1", + "myProfile", + ) + if err != nil { + t.Errorf("error: %s", err) + } + expected := formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123") + "\n\n" + formatCredential("accessKey", "secretKey", "sessionToken", "us-east-1", "myProfile") + if string(mockWriter.Out) != expected { + t.Errorf("Wrong output. Got: %s", mockWriter.Out) + } +} + +func TestSetCredentialsExistingFileReplace1(t *testing.T) { + mockGetter := MockCredentialFileGetter{ + CredentialOutput: formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123") + "\n\n" + formatCredential("accessKey2", "secretKey2", "sessionToken2", "us-east-2", "myProfile"), + CredentialExists: true, + } + mockWriter := MockCredentialFileWriter{} + err := SetCredentials( + mockGetter, + &mockWriter, + "accessKey", + "secretKey", + "sessionToken", + "us-east-1", + "myProfile", + ) + if err != nil { + t.Errorf("error: %s", err) + } + expected := formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123") + "\n\n" + formatCredential("accessKey", "secretKey", "sessionToken", "us-east-1", "myProfile") + if string(mockWriter.Out) != expected { + t.Errorf("Wrong output. Got: %s", mockWriter.Out) + } +} + +func TestSetCredentialsExistingFileReplace2(t *testing.T) { + mockGetter := MockCredentialFileGetter{ + CredentialOutput: formatCredential("accessKey2", "secretKey2", "sessionToken2", "us-east-2", "myProfile") + "\n\n" + formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123"), + CredentialExists: true, + } + mockWriter := MockCredentialFileWriter{} + err := SetCredentials( + mockGetter, + &mockWriter, + "accessKey", + "secretKey", + "sessionToken", + "us-east-1", + "myProfile", + ) + if err != nil { + t.Errorf("error: %s", err) + } + expected := formatCredential("accessKey", "secretKey", "sessionToken", "us-east-1", "myProfile") + "\n\n" + formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123") + if string(mockWriter.Out) != expected { + t.Errorf("Wrong output. Got: %s\n\nExpected: %s", mockWriter.Out, expected) + } +} + +func TestSetCredentialsExistingFileReplace3(t *testing.T) { + mockGetter := MockCredentialFileGetter{ + CredentialOutput: formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123") + "\n\n" + formatCredential("accessKey2", "secretKey2", "sessionToken2", "us-east-2", "myProfile") + "\n\n" + formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile1234"), + CredentialExists: true, + } + mockWriter := MockCredentialFileWriter{} + err := SetCredentials( + mockGetter, + &mockWriter, + "accessKey", + "secretKey", + "sessionToken", + "us-east-1", + "myProfile", + ) + if err != nil { + t.Errorf("error: %s", err) + } + expected := formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123") + "\n\n" + formatCredential("accessKey", "secretKey", "sessionToken", "us-east-1", "myProfile") + "\n\n" + formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile1234") + if string(mockWriter.Out) != expected { + t.Errorf("Wrong output. Got: %s\n\nExpected: %s", mockWriter.Out, expected) + } +}