Skip to content

Commit

Permalink
add timeout and retry logic to create_redis_pool
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Jun 5, 2017
1 parent 92c65ef commit 838bd58
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 6 deletions.
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ v0.8.0 (2017-06-05)
* change logger name for control process log messages
* use ``Semaphore`` rather than ``asyncio.wait(...return_when=asyncio.FIRST_COMPLETED)`` for improved performance
* improve log display
* add timeout and retry logic to ``RedisMixin.create_redis_pool``

v0.7.0 (2017-06-01)
...................
Expand Down
30 changes: 26 additions & 4 deletions arq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,33 @@
"""
import asyncio
import base64
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import Tuple, Union

import aioredis
from aioredis.pool import RedisPool
from async_timeout import timeout

__all__ = ['RedisSettings', 'RedisMixin']
logger = logging.getLogger('arq.utils')


class RedisSettings:
"""
No-Op class used to hold redis connection redis_settings.
"""
__slots__ = 'host', 'port', 'database', 'password', 'conn_retries', 'conn_timeout', 'conn_retry_delay'

def __init__(self,
host='localhost',
port=6379,
database=0,
password=None):
password=None,
conn_timeout=1,
conn_retries=5,
conn_retry_delay=1):
"""
:param host: redis host
:param port: redis port
Expand All @@ -35,6 +43,9 @@ def __init__(self,
self.port = port
self.database = database
self.password = password
self.conn_timeout = conn_timeout
self.conn_retries = conn_retries
self.conn_retry_delay = conn_retry_delay


class RedisMixin:
Expand All @@ -56,12 +67,23 @@ def __init__(self, *,
self.redis_settings = redis_settings or getattr(self, 'redis_settings', None) or RedisSettings()
self._redis_pool = existing_pool

async def create_redis_pool(self) -> RedisPool:
async def create_redis_pool(self, *, _retry=0) -> RedisPool:
"""
Create a new redis pool.
"""
return await aioredis.create_pool((self.redis_settings.host, self.redis_settings.port), loop=self.loop,
db=self.redis_settings.database, password=self.redis_settings.password)
addr = self.redis_settings.host, self.redis_settings.port
try:
with timeout(self.redis_settings.conn_timeout):
return await aioredis.create_pool(addr, loop=self.loop, db=self.redis_settings.database,
password=self.redis_settings.password)
except (ConnectionError, OSError, aioredis.RedisError, asyncio.TimeoutError) as e:
if _retry < self.redis_settings.conn_retries:
logger.warning('redis connection error %s %s, %d retries remaining...',
e.__class__.__name__, e, self.redis_settings.conn_retries - _retry)
await asyncio.sleep(self.redis_settings.conn_retry_delay)
return await self.create_redis_pool(_retry=_retry + 1)
else:
raise

async def get_redis_pool(self) -> RedisPool:
"""
Expand Down
3 changes: 2 additions & 1 deletion arq/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class BaseWorker(RedisMixin):
repeat_health_check_logs = False

drain_class = Drain
_shadow_factory_timeout = 10

def __init__(self, *,
burst: bool=False,
Expand Down Expand Up @@ -178,7 +179,7 @@ async def run(self):
self._stopped = False
work_logger.info('Initialising work manager, burst mode: %s, creating shadows...', self._burst_mode)

with timeout(10):
with timeout(self._shadow_factory_timeout):
shadows = await self.shadow_factory()
assert isinstance(shadows, list), 'shadow_factory should return a list not %s' % type(shadows)
self.job_class = shadows[0].job_class
Expand Down
13 changes: 12 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import os
from datetime import datetime

from arq import RedisSettings
import pytest

import arq.utils
from arq import RedisMixin, RedisSettings
from arq.logs import ColourHandler
from arq.testing import MockRedis
from arq.utils import timestamp
Expand Down Expand Up @@ -49,3 +52,11 @@ async def test_mock_redis_flushdb(loop):
assert 'bar' == await r.get('foo')
await r.flushdb()
assert None is await r.get('foo')


async def test_redis_timeout(loop, mocker):
mocker.spy(arq.utils.asyncio, 'sleep')
r = RedisMixin(redis_settings=RedisSettings(port=0, conn_retry_delay=0), loop=loop)
with pytest.raises(OSError):
await r.get_redis_pool()
assert arq.utils.asyncio.sleep.call_count == 5

0 comments on commit 838bd58

Please sign in to comment.