Redis Streams for immediate task delivery (python-arq#492)
RB387 authored and rossmacarthur committed Dec 21, 2024
1 parent 2f752e2 commit 00ea322
Showing 8 changed files with 571 additions and 100 deletions.
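With this change, a job enqueued with no defer time is published straight to a Redis stream for immediate pickup, while deferred jobs continue to use the sorted-set queue. A minimal caller-side sketch of the two paths (the send_email task name and its argument are placeholders):

    import asyncio

    from arq import create_pool
    from arq.connections import RedisSettings

    async def main() -> None:
        pool = await create_pool(RedisSettings())
        # Immediate: routed to the 'arq:queue:stream' stream via publish_job_lua.
        await pool.enqueue_job('send_email', 'user@example.com')
        # Deferred by 60s: still scored into the 'arq:queue' sorted set.
        await pool.enqueue_job('send_email', 'user@example.com', _defer_by=60)

    asyncio.run(main())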
74 changes: 67 additions & 7 deletions arq/connections.py
@@ -13,8 +13,16 @@
from redis.asyncio.sentinel import Sentinel
from redis.exceptions import RedisError, WatchError

from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix
from .constants import (
default_queue_name,
expires_extra_ms,
job_key_prefix,
job_message_id_prefix,
result_key_prefix,
stream_key_suffix,
)
from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
from .lua import publish_job_lua
from .utils import timestamp_ms, to_ms, to_unix_ms

logger = logging.getLogger('arq.connections')
@@ -165,20 +173,63 @@ async def enqueue_job(
elif defer_by_ms:
score = enqueue_time_ms + defer_by_ms
else:
score = enqueue_time_ms
score = None

expires_ms = expires_ms or score - enqueue_time_ms + self.expires_extra_ms
expires_ms = expires_ms or (score or enqueue_time_ms) - enqueue_time_ms + self.expires_extra_ms

job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer)
job = serialize_job(
function,
args,
kwargs,
_job_try,
enqueue_time_ms,
serializer=self.job_serializer,
)
pipe.multi()
pipe.psetex(job_key, expires_ms, job)
pipe.zadd(_queue_name, {job_id: score})

if score is not None:
pipe.zadd(_queue_name, {job_id: score})
else:
stream_key = _queue_name + stream_key_suffix
job_message_id_key = job_message_id_prefix + job_id
pipe.eval(
publish_job_lua,
2,
# keys
stream_key,
job_message_id_key,
# args
job_id,
str(enqueue_time_ms),
str(expires_ms),
)

try:
await pipe.execute()
except WatchError:
# job got enqueued since we checked 'job_exists'
return None
return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)
return Job(
job_id,
redis=self,
_queue_name=_queue_name,
_deserializer=self.job_deserializer,
)

async def get_queue_size(self, queue_name: str | None = None, include_delayed_tasks: bool = True) -> int:
if queue_name is None:
queue_name = self.default_queue_name

async with self.pipeline(transaction=True) as pipe:
pipe.xlen(queue_name + stream_key_suffix)
pipe.zcount(queue_name, '-inf', '+inf')
stream_size, delayed_queue_size = await pipe.execute()

if not include_delayed_tasks:
return stream_size

return stream_size + delayed_queue_size
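Usage sketch for the reworked queue size (inside an async context, with a pool as above); a plain ZCARD on the queue key no longer reflects ready work, since immediate jobs now live in the stream:

    total = await pool.get_queue_size()  # XLEN of the stream + ZCOUNT of the delayed zset
    ready = await pool.get_queue_size(include_delayed_tasks=False)  # stream entries only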

async def _get_job_result(self, key: bytes) -> JobResult:
job_id = key[len(result_key_prefix) :].decode()
@@ -213,7 +264,16 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef]
"""
if queue_name is None:
queue_name = self.default_queue_name
jobs = await self.zrange(queue_name, withscores=True, start=0, end=-1)

async with self.pipeline(transaction=True) as pipe:
pipe.zrange(queue_name, withscores=True, start=0, end=-1)
pipe.xrange(queue_name + stream_key_suffix, '-', '+')
delayed_jobs, stream_jobs = await pipe.execute()

jobs = [
*delayed_jobs,
*[(j[b'job_id'], int(j[b'score'])) for _, j in stream_jobs],
]
return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs])
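queued_jobs now merges both sources into one list of JobDefs; a quick sketch (again assuming a live pool in an async context):

    for job_def in await pool.queued_jobs():
        # Stream entries carry the score recorded at publish time;
        # delayed entries keep their sorted-set score.
        print(job_def.function, job_def.score)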


3 changes: 3 additions & 0 deletions arq/constants.py
@@ -1,9 +1,12 @@
default_queue_name = 'arq:queue'
job_key_prefix = 'arq:job:'
in_progress_key_prefix = 'arq:in-progress:'
job_message_id_prefix = 'arq:message-id:'
result_key_prefix = 'arq:result:'
retry_key_prefix = 'arq:retry:'
abort_jobs_ss = 'arq:abort'
stream_key_suffix = ':stream'
default_consumer_group = 'arq:workers'
# age of items in the abort_key sorted set after which they're deleted
abort_job_max_age = 60
health_check_key_suffix = ':health-check'
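For illustration, the concrete key names the new constants produce for the default queue and a hypothetical job id 'abc123':

    from arq.constants import default_queue_name, job_message_id_prefix, stream_key_suffix

    stream_key = default_queue_name + stream_key_suffix  # 'arq:queue:stream'
    message_id_key = job_message_id_prefix + 'abc123'    # 'arq:message-id:abc123'
    # Workers are expected to read the stream through the 'arq:workers'
    # consumer group (default_consumer_group).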
45 changes: 39 additions & 6 deletions arq/jobs.py
@@ -9,7 +9,16 @@

from redis.asyncio import Redis

from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix
from .constants import (
abort_jobs_ss,
default_queue_name,
in_progress_key_prefix,
job_key_prefix,
job_message_id_prefix,
result_key_prefix,
stream_key_suffix,
)
from .lua import get_job_from_stream_lua
from .utils import ms_to_datetime, poll, timestamp_ms

logger = logging.getLogger('arq.jobs')
@@ -63,6 +72,10 @@ class JobResult(JobDef):
queue_name: str


def _list_to_dict(input_list: list[Any]) -> dict[Any, Any]:
return dict(zip(input_list[::2], input_list[1::2], strict=True))
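XRANGE invoked through EVAL returns each entry's fields as a flat [field, value, ...] list, rather than the dict redis-py builds for a native xrange call; this helper pairs the items back up:

    >>> _list_to_dict([b'job_id', b'abc123', b'score', b'1735000000000'])
    {b'job_id': b'abc123', b'score': b'1735000000000'}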


class Job:
"""
Holds data and a reference to a job.
@@ -105,7 +118,8 @@ async def result(
async with self._redis.pipeline(transaction=True) as tr:
tr.get(result_key_prefix + self.job_id)
tr.zscore(self._queue_name, self.job_id)
v, s = await tr.execute()
tr.get(job_message_id_prefix + self.job_id)
v, s, m = await tr.execute()

if v:
info = deserialize_result(v, deserializer=self._deserializer)
@@ -115,7 +129,7 @@
raise info.result
else:
raise SerializationError(info.result)
elif s is None:
elif s is None and m is None:
raise ResultNotFound(
'Not waiting for job result because the job is not in queue. '
'Is the worker function configured to keep result?'
@@ -134,8 +148,24 @@ async def info(self) -> Optional[JobDef]:
if v:
info = deserialize_job(v, deserializer=self._deserializer)
if info:
s = await self._redis.zscore(self._queue_name, self.job_id)
info.score = None if s is None else int(s)
async with self._redis.pipeline(transaction=True) as tr:
tr.zscore(self._queue_name, self.job_id)
tr.eval(
get_job_from_stream_lua,
2,
self._queue_name + stream_key_suffix,
job_message_id_prefix + self.job_id,
)
delayed_score, job_info = await tr.execute()

if delayed_score:
info.score = int(delayed_score)
elif job_info:
_, job_info_payload = job_info
info.score = int(_list_to_dict(job_info_payload)[b'score'])
else:
info.score = None

return info

async def result_info(self) -> Optional[JobResult]:
@@ -157,12 +187,15 @@ async def status(self) -> JobStatus:
tr.exists(result_key_prefix + self.job_id)
tr.exists(in_progress_key_prefix + self.job_id)
tr.zscore(self._queue_name, self.job_id)
is_complete, is_in_progress, score = await tr.execute()
tr.exists(job_message_id_prefix + self.job_id)
is_complete, is_in_progress, score, queued = await tr.execute()

if is_complete:
return JobStatus.complete
elif is_in_progress:
return JobStatus.in_progress
elif queued:
return JobStatus.queued
elif score:
return JobStatus.deferred if score > timestamp_ms() else JobStatus.queued
else:
48 changes: 48 additions & 0 deletions arq/lua.py
@@ -0,0 +1,48 @@
publish_delayed_job_lua = """
local delayed_queue_key = KEYS[1]
local stream_key = KEYS[2]
local job_message_id_key = KEYS[3]
local job_id = ARGV[1]
local job_message_id_expire_ms = ARGV[2]
local score = redis.call('zscore', delayed_queue_key, job_id)
if score == nil or score == false then
return 0
end
local message_id = redis.call('xadd', stream_key, '*', 'job_id', job_id, 'score', score)
redis.call('set', job_message_id_key, message_id, 'px', job_message_id_expire_ms)
redis.call('zrem', delayed_queue_key, job_id)
return 1
"""

publish_job_lua = """
local stream_key = KEYS[1]
local job_message_id_key = KEYS[2]
local job_id = ARGV[1]
local score = ARGV[2]
local job_message_id_expire_ms = ARGV[3]
local message_id = redis.call('xadd', stream_key, '*', 'job_id', job_id, 'score', score)
redis.call('set', job_message_id_key, message_id, 'px', job_message_id_expire_ms)
return message_id
"""

get_job_from_stream_lua = """
local stream_key = KEYS[1]
local job_message_id_key = KEYS[2]
local message_id = redis.call('get', job_message_id_key)
if message_id == false then
return nil
end
local job = redis.call('xrange', stream_key, message_id, message_id)
if job == nil then
return nil
end
return job[1]
"""