-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
3 changed files
with
311 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,151 @@ | ||
package aws | ||
|
||
import ( | ||
"bufio" | ||
"errors" | ||
"fmt" | ||
"log" | ||
"os/exec" | ||
"io/ioutil" | ||
"os" | ||
"strings" | ||
) | ||
|
||
func SetCredentials(accessKey, secretAccessKey, sessionToken, region, profileName string) { | ||
err := setProfileParameter("aws_access_key_id", accessKey, profileName) | ||
type CredentialFileGetterIface interface { | ||
Get() (string, error) | ||
Exists() (bool, error) | ||
SetHomeDir(string) | ||
} | ||
type CredentialFileWriterIface interface { | ||
Write(b []byte) error | ||
SetHomeDir(string) | ||
} | ||
|
||
type CredentialFileGetter struct { | ||
Homedir string | ||
} | ||
|
||
func (c *CredentialFileGetter) SetHomeDir(homedir string) { | ||
c.Homedir = homedir | ||
} | ||
func (c *CredentialFileGetter) Get() (string, error) { | ||
credentials, err := ioutil.ReadFile(c.Homedir + "/.aws/credentials") | ||
if err != nil { | ||
log.Fatalln("Error while writing credentials to file. ", err) | ||
return "", fmt.Errorf("Could not read file: %s", err) | ||
} | ||
err = setProfileParameter("aws_secret_access_key", secretAccessKey, profileName) | ||
if err != nil { | ||
log.Fatalln("Error while writing credentials to file. ", err) | ||
|
||
return string(credentials), nil | ||
} | ||
func (c *CredentialFileGetter) Exists() (bool, error) { | ||
credentialsFileExists := false | ||
|
||
if _, err := os.Stat(c.Homedir + "/.aws/credentials"); err == nil { | ||
credentialsFileExists = true | ||
} else if errors.Is(err, os.ErrNotExist) { | ||
credentialsFileExists = false | ||
} else { | ||
return credentialsFileExists, fmt.Errorf("Could not determine if credential file exists: %s", err) | ||
} | ||
err = setProfileParameter("aws_session_token", sessionToken, profileName) | ||
return credentialsFileExists, nil | ||
} | ||
|
||
type CredentialFileWriter struct { | ||
Homedir string | ||
} | ||
|
||
func (c *CredentialFileWriter) Write(b []byte) error { | ||
return os.WriteFile(c.Homedir+"/.aws/credentials", b, 0600) | ||
} | ||
func (c *CredentialFileWriter) SetHomeDir(homedir string) { | ||
c.Homedir = homedir | ||
} | ||
|
||
func SetCredentials(credFileGetter CredentialFileGetterIface, credFileWriter CredentialFileWriterIface, accessKey, secretAccessKey, sessionToken, region, profileName string) error { | ||
// set user homedir | ||
homedir, err := os.UserHomeDir() | ||
if err != nil { | ||
log.Fatalln("Error while writing credentials to file. ", err) | ||
return fmt.Errorf("could not get user's home directory: %s", err) | ||
} | ||
err = setProfileParameter("region", region, profileName) | ||
credFileGetter.SetHomeDir(homedir) | ||
credFileWriter.SetHomeDir(homedir) | ||
|
||
// check whether file exists | ||
exists, err := credFileGetter.Exists() | ||
if err != nil { | ||
log.Fatalln("Error while writing credentials to file. ", err) | ||
return fmt.Errorf("Couldn't determine whether credential file exists: %s", err) | ||
} | ||
|
||
// if exists, update / write file | ||
if exists { | ||
credentials, err := credFileGetter.Get() | ||
if err != nil { | ||
return fmt.Errorf("Couldn't get credential file: %s", err) | ||
} | ||
if profileExists(credentials, profileName) { | ||
err = updateCredential(credFileWriter, credentials, accessKey, secretAccessKey, sessionToken, region, profileName) | ||
if err != nil { | ||
return fmt.Errorf("updateCredential error: %s", err) | ||
} | ||
} else { | ||
err = appendCredential(credFileWriter, credentials, accessKey, secretAccessKey, sessionToken, region, profileName) | ||
if err != nil { | ||
return fmt.Errorf("appendCredential error: %s", err) | ||
} | ||
} | ||
} else { // write new file if credential file doesn't exist | ||
err = writeCredentialAsNewFile(credFileWriter, accessKey, secretAccessKey, sessionToken, region, profileName) | ||
if err != nil { | ||
return fmt.Errorf("writeCredentialAsNewFile error: %s", err) | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func setProfileParameter(parameter, value, profileName string) error { | ||
command := []string{ | ||
"aws", | ||
"configure", | ||
"set", | ||
parameter, | ||
value, | ||
"--profile", | ||
profileName, | ||
func profileExists(credentials string, profileName string) bool { | ||
scanner := bufio.NewScanner(strings.NewReader(credentials)) | ||
for scanner.Scan() { | ||
if strings.Trim(scanner.Text(), " ") == "["+profileName+"]" { | ||
return true | ||
} | ||
} | ||
return false | ||
} | ||
|
||
cmd := exec.Command(command[0], command[1], command[2], command[3], command[4], command[5], command[6]) | ||
_, err := cmd.Output() | ||
func updateCredential(writer CredentialFileWriterIface, credentials, accessKey, secretAccessKey, sessionToken, region, profileName string) error { | ||
newCredentials := "" | ||
found := false | ||
scanner := bufio.NewScanner(strings.NewReader(credentials)) | ||
for scanner.Scan() { | ||
line := scanner.Text() | ||
if strings.Trim(line, " ") == "["+profileName+"]" { | ||
found = true | ||
newCredentials += formatCredential(accessKey, secretAccessKey, sessionToken, region, profileName) | ||
} | ||
|
||
if err != nil { | ||
fmt.Println(err.Error()) | ||
return err | ||
if !found { | ||
newCredentials += line + "\n" | ||
} | ||
|
||
if found && line == "" { | ||
found = false | ||
newCredentials += "\n\n" | ||
} | ||
} | ||
return nil | ||
newCredentials = strings.TrimRight(newCredentials, "\n") | ||
return writer.Write([]byte(newCredentials)) | ||
} | ||
func appendCredential(writer CredentialFileWriterIface, credentials, accessKey, secretAccessKey, sessionToken, region, profileName string) error { | ||
credentials = strings.TrimRight(credentials, "\n") | ||
if credentials != "" { | ||
credentials += "\n\n" | ||
} | ||
credentials += formatCredential(accessKey, secretAccessKey, sessionToken, region, profileName) | ||
return writer.Write([]byte(credentials)) | ||
} | ||
|
||
func writeCredentialAsNewFile(writer CredentialFileWriterIface, accessKey, secretAccessKey, sessionToken, region, profileName string) error { | ||
return writer.Write([]byte(formatCredential(accessKey, secretAccessKey, sessionToken, region, profileName))) | ||
} | ||
|
||
func formatCredential(accessKey, secretAccessKey, sessionToken, region, profileName string) string { | ||
return fmt.Sprintf("[%s]\naws_access_key_id = %s\naws_secret_access_key = %s\naws_session_token = %s", profileName, accessKey, secretAccessKey, sessionToken) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
package aws | ||
|
||
import ( | ||
"testing" | ||
) | ||
|
||
type MockCredentialFileGetter struct { | ||
CredentialOutput string | ||
CredentialExists bool | ||
} | ||
type MockCredentialFileWriter struct { | ||
Out []byte | ||
} | ||
|
||
func (c MockCredentialFileGetter) SetHomeDir(homedir string) { | ||
|
||
} | ||
func (c MockCredentialFileGetter) Get() (string, error) { | ||
return c.CredentialOutput, nil | ||
} | ||
func (c MockCredentialFileGetter) Exists() (bool, error) { | ||
return c.CredentialExists, nil | ||
} | ||
|
||
func (c MockCredentialFileWriter) SetHomeDir(homedir string) { | ||
|
||
} | ||
func (c *MockCredentialFileWriter) Write(b []byte) error { | ||
c.Out = b | ||
return nil | ||
} | ||
|
||
func TestSetCredentialsEmpty(t *testing.T) { | ||
mockGetter := MockCredentialFileGetter{ | ||
CredentialOutput: "", | ||
CredentialExists: true, | ||
} | ||
mockWriter := MockCredentialFileWriter{} | ||
err := SetCredentials( | ||
mockGetter, | ||
&mockWriter, | ||
"accessKey", | ||
"secretKey", | ||
"sessionToken", | ||
"us-eas-1", | ||
"myProfile", | ||
) | ||
if err != nil { | ||
t.Errorf("error: %s", err) | ||
} | ||
if string(mockWriter.Out) != formatCredential("accessKey", "secretKey", "sessionToken", "us-eas-1", "myProfile") { | ||
t.Errorf("Wrong output. Got: %s", mockWriter.Out) | ||
} | ||
} | ||
|
||
func TestSetCredentialsNonExisting(t *testing.T) { | ||
mockGetter := MockCredentialFileGetter{ | ||
CredentialOutput: "", | ||
CredentialExists: false, | ||
} | ||
mockWriter := MockCredentialFileWriter{} | ||
err := SetCredentials( | ||
mockGetter, | ||
&mockWriter, | ||
"accessKey", | ||
"secretKey", | ||
"sessionToken", | ||
"us-eas-1", | ||
"myProfile", | ||
) | ||
if err != nil { | ||
t.Errorf("error: %s", err) | ||
} | ||
if string(mockWriter.Out) != formatCredential("accessKey", "secretKey", "sessionToken", "us-eas-1", "myProfile") { | ||
t.Errorf("Wrong output. Got: %s", mockWriter.Out) | ||
} | ||
} | ||
func TestSetCredentialsExistingFile(t *testing.T) { | ||
mockGetter := MockCredentialFileGetter{ | ||
CredentialOutput: formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123"), | ||
CredentialExists: true, | ||
} | ||
mockWriter := MockCredentialFileWriter{} | ||
err := SetCredentials( | ||
mockGetter, | ||
&mockWriter, | ||
"accessKey", | ||
"secretKey", | ||
"sessionToken", | ||
"us-east-1", | ||
"myProfile", | ||
) | ||
if err != nil { | ||
t.Errorf("error: %s", err) | ||
} | ||
expected := formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123") + "\n\n" + formatCredential("accessKey", "secretKey", "sessionToken", "us-east-1", "myProfile") | ||
if string(mockWriter.Out) != expected { | ||
t.Errorf("Wrong output. Got: %s", mockWriter.Out) | ||
} | ||
} | ||
|
||
func TestSetCredentialsExistingFileReplace1(t *testing.T) { | ||
mockGetter := MockCredentialFileGetter{ | ||
CredentialOutput: formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123") + "\n\n" + formatCredential("accessKey2", "secretKey2", "sessionToken2", "us-east-2", "myProfile"), | ||
CredentialExists: true, | ||
} | ||
mockWriter := MockCredentialFileWriter{} | ||
err := SetCredentials( | ||
mockGetter, | ||
&mockWriter, | ||
"accessKey", | ||
"secretKey", | ||
"sessionToken", | ||
"us-east-1", | ||
"myProfile", | ||
) | ||
if err != nil { | ||
t.Errorf("error: %s", err) | ||
} | ||
expected := formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123") + "\n\n" + formatCredential("accessKey", "secretKey", "sessionToken", "us-east-1", "myProfile") | ||
if string(mockWriter.Out) != expected { | ||
t.Errorf("Wrong output. Got: %s", mockWriter.Out) | ||
} | ||
} | ||
|
||
func TestSetCredentialsExistingFileReplace2(t *testing.T) { | ||
mockGetter := MockCredentialFileGetter{ | ||
CredentialOutput: formatCredential("accessKey2", "secretKey2", "sessionToken2", "us-east-2", "myProfile") + "\n\n" + formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123"), | ||
CredentialExists: true, | ||
} | ||
mockWriter := MockCredentialFileWriter{} | ||
err := SetCredentials( | ||
mockGetter, | ||
&mockWriter, | ||
"accessKey", | ||
"secretKey", | ||
"sessionToken", | ||
"us-east-1", | ||
"myProfile", | ||
) | ||
if err != nil { | ||
t.Errorf("error: %s", err) | ||
} | ||
expected := formatCredential("accessKey", "secretKey", "sessionToken", "us-east-1", "myProfile") + "\n\n" + formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123") | ||
if string(mockWriter.Out) != expected { | ||
t.Errorf("Wrong output. Got: %s\n\nExpected: %s", mockWriter.Out, expected) | ||
} | ||
} | ||
|
||
func TestSetCredentialsExistingFileReplace3(t *testing.T) { | ||
mockGetter := MockCredentialFileGetter{ | ||
CredentialOutput: formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123") + "\n\n" + formatCredential("accessKey2", "secretKey2", "sessionToken2", "us-east-2", "myProfile") + "\n\n" + formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile1234"), | ||
CredentialExists: true, | ||
} | ||
mockWriter := MockCredentialFileWriter{} | ||
err := SetCredentials( | ||
mockGetter, | ||
&mockWriter, | ||
"accessKey", | ||
"secretKey", | ||
"sessionToken", | ||
"us-east-1", | ||
"myProfile", | ||
) | ||
if err != nil { | ||
t.Errorf("error: %s", err) | ||
} | ||
expected := formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile123") + "\n\n" + formatCredential("accessKey", "secretKey", "sessionToken", "us-east-1", "myProfile") + "\n\n" + formatCredential("accessKey1", "secretKey1", "sessionToken1", "us-east-2", "profile1234") | ||
if string(mockWriter.Out) != expected { | ||
t.Errorf("Wrong output. Got: %s\n\nExpected: %s", mockWriter.Out, expected) | ||
} | ||
} |