From 729506800effff176027dcb3c784f69434e5e71c Mon Sep 17 00:00:00 2001 From: Alessandra Romero Date: Wed, 8 Jan 2025 17:20:55 -0500 Subject: [PATCH] Add account ID to the environment credentials provider --- botocore/credentials.py | 21 +++++++++++++++++++-- tests/unit/test_credentials.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/botocore/credentials.py b/botocore/credentials.py index dd7e718255..90c7829760 100644 --- a/botocore/credentials.py +++ b/botocore/credentials.py @@ -310,6 +310,7 @@ class Credentials: :param str token: The security token, valid only for session credentials. :param str method: A string which identifies where the credentials were found. + :param str account_id: The account ID associated with the credentials. """ def __init__(self, access_key, secret_key, token=None, method=None): @@ -1095,6 +1096,7 @@ class EnvProvider(CredentialProvider): # AWS_SESSION_TOKEN is what other AWS SDKs have standardized on. TOKENS = ['AWS_SECURITY_TOKEN', 'AWS_SESSION_TOKEN'] EXPIRY_TIME = 'AWS_CREDENTIAL_EXPIRATION' + ACCOUNT_ID = 'AWS_ACCOUNT_ID' def __init__(self, environ=None, mapping=None): """ @@ -1104,8 +1106,12 @@ def __init__(self, environ=None, mapping=None): :param mapping: An optional mapping of variable names to environment variable names. Use this if you want to change the mapping of access_key->AWS_ACCESS_KEY_ID, etc. - The dict can have up to 3 keys: ``access_key``, ``secret_key``, - ``session_token``. + The dict can have up to 5 keys: + * ``access_key`` + * ``secret_key`` + * ``token`` + * ``expiry_time`` + * ``account_id`` """ if environ is None: environ = os.environ @@ -1121,6 +1127,7 @@ def _build_mapping(self, mapping): var_mapping['secret_key'] = self.SECRET_KEY var_mapping['token'] = self.TOKENS var_mapping['expiry_time'] = self.EXPIRY_TIME + var_mapping['account_id'] = self.ACCOUNT_ID else: var_mapping['access_key'] = mapping.get( 'access_key', self.ACCESS_KEY @@ -1134,6 +1141,9 @@ def _build_mapping(self, mapping): var_mapping['expiry_time'] = mapping.get( 'expiry_time', self.EXPIRY_TIME ) + var_mapping['account_id'] = mapping.get( + 'account_id', self.ACCOUNT_ID + ) return var_mapping def load(self): @@ -1158,6 +1168,7 @@ def load(self): expiry_time, refresh_using=fetcher, method=self.METHOD, + account_id=credentials['account_id'], ) return Credentials( @@ -1165,6 +1176,7 @@ def load(self): credentials['secret_key'], credentials['token'], method=self.METHOD, + account_id=credentials['account_id'], ) else: return None @@ -1207,6 +1219,11 @@ def fetch_credentials(require_expiry=True): provider=method, cred_var=mapping['expiry_time'] ) + credentials['account_id'] = None + account_id = environ.get(mapping['account_id'], '') + if account_id: + credentials['account_id'] = account_id + return credentials return fetch_credentials diff --git a/tests/unit/test_credentials.py b/tests/unit/test_credentials.py index 3017444a6f..28608eb5cf 100644 --- a/tests/unit/test_credentials.py +++ b/tests/unit/test_credentials.py @@ -1020,6 +1020,20 @@ def test_envvars_found_with_session_token(self): self.assertEqual(creds.token, 'baz') self.assertEqual(creds.method, 'env') + def test_envvars_found_with_account_id(self): + environ = { + 'AWS_ACCESS_KEY_ID': 'foo', + 'AWS_SECRET_ACCESS_KEY': 'bar', + 'AWS_ACCOUNT_ID': 'baz', + } + provider = credentials.EnvProvider(environ) + creds = provider.load() + self.assertIsNotNone(creds) + self.assertEqual(creds.access_key, 'foo') + self.assertEqual(creds.secret_key, 'bar') + self.assertEqual(creds.account_id, 'baz') + self.assertEqual(creds.method, 'env') + def test_envvars_not_found(self): provider = credentials.EnvProvider(environ={}) creds = provider.load() @@ -1127,6 +1141,22 @@ def test_can_override_expiry_env_var_mapping(self): with self.assertRaisesRegex(RuntimeError, error_message): creds.get_frozen_credentials() + def test_can_override_account_id_env_var_mapping(self): + environ = { + 'AWS_ACCESS_KEY_ID': 'foo', + 'AWS_SECRET_ACCESS_KEY': 'bar', + 'AWS_SESSION_TOKEN': 'baz', + 'FOO_ACCOUNT_ID': 'bin', + } + provider = credentials.EnvProvider( + environ, {'account_id': 'FOO_ACCOUNT_ID'} + ) + creds = provider.load() + self.assertEqual(creds.access_key, 'foo') + self.assertEqual(creds.secret_key, 'bar') + self.assertEqual(creds.token, 'baz') + self.assertEqual(creds.account_id, 'bin') + def test_partial_creds_is_an_error(self): # If the user provides an access key, they must also # provide a secret key. Not doing so will generate an