Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add account ID to the environment variable credential provider #3332

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions botocore/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ class Credentials:
:param str token: The security token, valid only for session credentials.
:param str method: A string which identifies where the credentials
were found.
:param str account_id: (optional) An account ID associated with the credentials.
"""

def __init__(
Expand Down Expand Up @@ -1118,6 +1119,7 @@ class EnvProvider(CredentialProvider):
# AWS_SESSION_TOKEN is what other AWS SDKs have standardized on.
TOKENS = ['AWS_SECURITY_TOKEN', 'AWS_SESSION_TOKEN']
EXPIRY_TIME = 'AWS_CREDENTIAL_EXPIRATION'
ACCOUNT_ID = 'AWS_ACCOUNT_ID'

def __init__(self, environ=None, mapping=None):
"""
Expand All @@ -1127,8 +1129,12 @@ def __init__(self, environ=None, mapping=None):
:param mapping: An optional mapping of variable names to
environment variable names. Use this if you want to
change the mapping of access_key->AWS_ACCESS_KEY_ID, etc.
The dict can have up to 3 keys: ``access_key``, ``secret_key``,
``session_token``.
The dict can have up to 5 keys:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current mapping is:

{
    'access_key': 'AWS_ACCESS_KEY_ID',
    'secret_key': 'AWS_SECRET_ACCESS_KEY',
    'token': ['AWS_SECURITY_TOKEN', 'AWS_SESSION_TOKEN'],
    'expiry_time': 'AWS_CREDENTIAL_EXPIRATION',
    'account_id': 'AWS_ACCOUNT_ID',
}

* ``access_key``
* ``secret_key``
* ``token``
* ``expiry_time``
* ``account_id``
"""
if environ is None:
environ = os.environ
Expand All @@ -1144,6 +1150,7 @@ def _build_mapping(self, mapping):
var_mapping['secret_key'] = self.SECRET_KEY
var_mapping['token'] = self.TOKENS
var_mapping['expiry_time'] = self.EXPIRY_TIME
var_mapping['account_id'] = self.ACCOUNT_ID
else:
var_mapping['access_key'] = mapping.get(
'access_key', self.ACCESS_KEY
Expand All @@ -1157,6 +1164,9 @@ def _build_mapping(self, mapping):
var_mapping['expiry_time'] = mapping.get(
'expiry_time', self.EXPIRY_TIME
)
var_mapping['account_id'] = mapping.get(
'account_id', self.ACCOUNT_ID
)
return var_mapping

def load(self):
Expand All @@ -1181,13 +1191,15 @@ def load(self):
expiry_time,
refresh_using=fetcher,
method=self.METHOD,
account_id=credentials['account_id'],
)

return Credentials(
credentials['access_key'],
credentials['secret_key'],
credentials['token'],
method=self.METHOD,
account_id=credentials['account_id'],
)
else:
return None
Expand Down Expand Up @@ -1230,6 +1242,11 @@ def fetch_credentials(require_expiry=True):
provider=method, cred_var=mapping['expiry_time']
)

credentials['account_id'] = None
account_id = environ.get(mapping['account_id'], '')
if account_id:
credentials['account_id'] = account_id

return credentials

return fetch_credentials
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,20 @@ def test_envvars_found_with_session_token(self):
self.assertEqual(creds.token, 'baz')
self.assertEqual(creds.method, 'env')

def test_envvars_found_with_account_id(self):
environ = {
'AWS_ACCESS_KEY_ID': 'foo',
'AWS_SECRET_ACCESS_KEY': 'bar',
'AWS_ACCOUNT_ID': 'baz',
}
provider = credentials.EnvProvider(environ)
creds = provider.load()
self.assertIsNotNone(creds)
self.assertEqual(creds.access_key, 'foo')
self.assertEqual(creds.secret_key, 'bar')
self.assertEqual(creds.account_id, 'baz')
self.assertEqual(creds.method, 'env')

def test_envvars_not_found(self):
provider = credentials.EnvProvider(environ={})
creds = provider.load()
Expand Down Expand Up @@ -1127,6 +1141,22 @@ def test_can_override_expiry_env_var_mapping(self):
with self.assertRaisesRegex(RuntimeError, error_message):
creds.get_frozen_credentials()

def test_can_override_account_id_env_var_mapping(self):
environ = {
'AWS_ACCESS_KEY_ID': 'foo',
'AWS_SECRET_ACCESS_KEY': 'bar',
'AWS_SESSION_TOKEN': 'baz',
'FOO_ACCOUNT_ID': 'bin',
}
provider = credentials.EnvProvider(
environ, {'account_id': 'FOO_ACCOUNT_ID'}
)
creds = provider.load()
self.assertEqual(creds.access_key, 'foo')
self.assertEqual(creds.secret_key, 'bar')
self.assertEqual(creds.token, 'baz')
self.assertEqual(creds.account_id, 'bin')

def test_partial_creds_is_an_error(self):
# If the user provides an access key, they must also
# provide a secret key. Not doing so will generate an
Expand Down
Loading