Skip to content

Commit

Permalink
Merge pull request #3 from zalando-stups/2-use-valid-token-on-refresh…
Browse files Browse the repository at this point in the history
…-failure

#2 return still valid token on failing refresh
  • Loading branch information
hjacobs committed Feb 3, 2016
2 parents 93487b2 + 91292f4 commit 4450c15
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 8 deletions.
113 changes: 113 additions & 0 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,58 @@
import tokens
from mock import MagicMock


VALID_USER_JSON = {'application_username': 'app', 'application_password': 'pass'}
VALID_CLIENT_JSON = {'client_id': 'cid', 'client_secret': 'sec'}


def test_init_fixed_tokens_from_env(monkeypatch):
monkeypatch.setattr('os.environ', {'OAUTH2_ACCESS_TOKENS': 'mytok=123,t2=3'})
tokens.init_fixed_tokens_from_env()
assert '123' == tokens.get('mytok')
assert '3' == tokens.get('t2')


def test_read_credentials(tmpdir):
path = str(tmpdir)
user = VALID_USER_JSON
client = VALID_CLIENT_JSON
with open(os.path.join(path, 'user.json'), 'w') as fd:
json.dump(user, fd)

with open(os.path.join(path, 'client.json'), 'w') as fd:
json.dump(client, fd)

assert (user, client) == tokens.read_credentials(path)

with open(os.path.join(path, 'client.json'), 'w') as fd:
fd.write('invalid')

with pytest.raises(tokens.InvalidCredentialsError) as exc_info:
tokens.read_credentials(path)

with open(os.path.join(path, 'user.json'), 'w') as fd:
fd.write('invalid')

with pytest.raises(tokens.InvalidCredentialsError) as exc_info:
tokens.read_credentials(path)


def test_get():
tokens.TOKENS = {'test': {'access_token': 'mytok123',
'expires_at': time.time() + 3600}}
tokens.get('test')


def test_refresh_without_configuration():
# remove URL config
tokens.configure(dir='', url='')
tokens.manage('mytok', ['scope'])
with pytest.raises(tokens.ConfigurationError) as exc_info:
tokens.refresh('mytok')
assert str(exc_info.value) == 'Configuration error: Missing OAuth access token URL. Either set OAUTH2_ACCESS_TOKEN_URL or use tokens.configure(url=..).'


def test_refresh(monkeypatch, tmpdir):
tokens.configure(dir=str(tmpdir), url='')
tokens.manage('mytok', ['myscope'])
Expand All @@ -30,3 +76,70 @@ def test_refresh(monkeypatch, tmpdir):
monkeypatch.setattr('requests.post', lambda url, **kwargs: response)
tok = tokens.get('mytok')
assert tok == '777'


def test_refresh_invalid_credentials(monkeypatch, tmpdir):
tokens.configure(dir=str(tmpdir), url='https://example.org')
tokens.manage('mytok', ['myscope'])
tokens.start() # this does not do anything..

with open(os.path.join(str(tmpdir), 'user.json'), 'w') as fd:
# missing password
json.dump({'application_username': 'app'}, fd)

with open(os.path.join(str(tmpdir), 'client.json'), 'w') as fd:
json.dump({'client_id': 'cid', 'client_secret': 'sec'}, fd)

with pytest.raises(tokens.InvalidCredentialsError) as exc_info:
tokens.get('mytok')
assert str(exc_info.value) == "Invalid OAuth credentials: Missing key: 'application_password'"


def test_refresh_invalid_response(monkeypatch, tmpdir):
tokens.configure(dir=str(tmpdir), url='https://example.org')
tokens.manage('mytok', ['myscope'])
tokens.start() # this does not do anything..

response = MagicMock()
response.json.return_value = {'foo': 'bar'}
monkeypatch.setattr('requests.post', lambda url, **kwargs: response)
monkeypatch.setattr('tokens.read_credentials', lambda path: (VALID_USER_JSON, VALID_CLIENT_JSON))

with pytest.raises(tokens.InvalidTokenResponse) as exc_info:
tokens.get('mytok')
assert str(exc_info.value) == """Invalid token response: Expected a JSON object with keys "expires_in" and "access_token": 'expires_in'"""

response.json.return_value = {'access_token': '', 'expires_in': 100}
with pytest.raises(tokens.InvalidTokenResponse) as exc_info:
tokens.get('mytok')
assert str(exc_info.value) == 'Invalid token response: Empty "access_token" value'


def test_get_refresh_failure(monkeypatch, tmpdir):
tokens.configure(dir=str(tmpdir), url='https://example.org')

with open(os.path.join(str(tmpdir), 'user.json'), 'w') as fd:
json.dump({'application_username': 'app', 'application_password': 'pass'}, fd)

with open(os.path.join(str(tmpdir), 'client.json'), 'w') as fd:
json.dump({'client_id': 'cid', 'client_secret': 'sec'}, fd)

exc = Exception('FAIL')
response = MagicMock()
response.raise_for_status.side_effect = exc
monkeypatch.setattr('requests.post', lambda url, **kwargs: response)
logger = MagicMock()
monkeypatch.setattr('tokens.logger', logger)
tokens.TOKENS = {'mytok': {'access_token': 'oldtok',
'scopes': ['myscope'],
# token is still valid for 10 minutes
'expires_at': time.time() + (10 * 60)}}
tok = tokens.get('mytok')
assert tok == 'oldtok'
logger.warn.assert_called_with('Failed to refresh access token "%s" (but it is still valid): %s', 'mytok', exc)

tokens.TOKENS = {'mytok': {'scopes': ['myscope']}}
with pytest.raises(Exception) as exc_info:
tok = tokens.get('mytok')
assert exc_info.value == exc

38 changes: 30 additions & 8 deletions tokens/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
logger = logging.getLogger('tokens')

ONE_YEAR = 3600*24*365
EXPIRATION_TOLERANCE_SECS = 60
# TODO: make time value configurable (20 minutes)?
REFRESH_BEFORE_SECS_LEFT = 20 * 60

CONFIG = {'url': os.environ.get('OAUTH2_ACCESS_TOKEN_URL', os.environ.get('OAUTH_ACCESS_TOKEN_URL')),
'dir': os.environ.get('CREDENTIALS_DIR', '')}
Expand All @@ -32,6 +35,14 @@ def __str__(self):
return 'Invalid OAuth credentials: {}'.format(self.msg)


class InvalidTokenResponse(Exception):
def __init__(self, msg):
self.msg = msg

def __str__(self):
return 'Invalid token response: {}'.format(self.msg)


def init_fixed_tokens_from_env():
env_val = os.environ.get('OAUTH2_ACCESS_TOKENS', '')
for part in filter(None, env_val.split(',')):
Expand Down Expand Up @@ -96,19 +107,30 @@ def refresh(token_name):

r = requests.post(url, data=body, auth=auth)
r.raise_for_status()
data = r.json()
token['data'] = data
token['expires_at'] = time.time() + data.get('expires_in')
token['access_token'] = data.get('access_token')
try:
data = r.json()
token['data'] = data
token['expires_at'] = time.time() + data['expires_in']
token['access_token'] = data['access_token']
except Exception as e:
raise InvalidTokenResponse('Expected a JSON object with keys "expires_in" and "access_token": {}'.format(e))
if not token['access_token']:
raise InvalidTokenResponse('Empty "access_token" value')
return token


def get(token_name):
token = TOKENS[token_name]
access_token = token.get('access_token')
# TODO: remove hardcoded time value (20 minutes)
if not access_token or time.time() > token['expires_at'] - 20*60:
token = refresh(token_name)
if not access_token or time.time() > token['expires_at'] - REFRESH_BEFORE_SECS_LEFT:
try:
token = refresh(token_name)
access_token = token.get('access_token')
except Exception as e:
if access_token and time.time() < token['expires_at'] + EXPIRATION_TOLERANCE_SECS:
# apply some tolerance, still try our old token if it's still valid
logger.warn('Failed to refresh access token "%s" (but it is still valid): %s', token_name, e)
else:
raise

access_token = token.get('access_token')
return access_token

0 comments on commit 4450c15

Please sign in to comment.