PoC: Redis Streams for immediate task delivery #492

Draft · wants to merge 12 commits into main
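In broad strokes, this PoC splits the queue in two: jobs enqueued without a deferral are XADDed to a stream (`default_queue_name + stream_key_suffix`, i.e. `arq:queue:stream`) for immediate delivery, while deferred jobs stay in the `arq:queue` sorted set until due. A quick way to inspect the split after enqueuing — a sketch only, with key names taken from `arq/constants.py`:

```python
# Sketch: inspecting both queue structures after this change.
# Key names follow arq/constants.py ('arq:queue' + ':stream', etc.).
import asyncio

from redis.asyncio import Redis


async def inspect() -> None:
    redis = Redis()
    # Immediate jobs: one stream entry per job with 'job_id' and 'score' fields.
    print(await redis.xrange('arq:queue:stream', '-', '+'))
    # Deferred jobs: member is the job_id, score is the due time in ms.
    print(await redis.zrange('arq:queue', 0, -1, withscores=True))


asyncio.run(inspect())
```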
74 changes: 67 additions & 7 deletions arq/connections.py
@@ -13,8 +13,16 @@
from redis.asyncio.sentinel import Sentinel
from redis.exceptions import RedisError, WatchError

from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix
from .constants import (
default_queue_name,
expires_extra_ms,
job_key_prefix,
job_message_id_prefix,
result_key_prefix,
stream_key_suffix,
)
from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
from .lua import publish_job_lua
from .utils import timestamp_ms, to_ms, to_unix_ms

logger = logging.getLogger('arq.connections')
@@ -165,20 +173,63 @@ async def enqueue_job(
elif defer_by_ms:
score = enqueue_time_ms + defer_by_ms
else:
score = enqueue_time_ms
score = None

expires_ms = expires_ms or score - enqueue_time_ms + self.expires_extra_ms
expires_ms = expires_ms or (score or enqueue_time_ms) - enqueue_time_ms + self.expires_extra_ms

job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer)
job = serialize_job(
function,
args,
kwargs,
_job_try,
enqueue_time_ms,
serializer=self.job_serializer,
)
pipe.multi()
pipe.psetex(job_key, expires_ms, job)
pipe.zadd(_queue_name, {job_id: score})

if score is not None:
pipe.zadd(_queue_name, {job_id: score})
else:
stream_key = _queue_name + stream_key_suffix
job_message_id_key = job_message_id_prefix + job_id
pipe.eval(
publish_job_lua,
2,
# keys
stream_key,
job_message_id_key,
# args
job_id,
str(enqueue_time_ms),
str(expires_ms),
)

try:
await pipe.execute()
except WatchError:
# job got enqueued since we checked 'job_exists'
return None
return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)
return Job(
job_id,
redis=self,
_queue_name=_queue_name,
_deserializer=self.job_deserializer,
)
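The effect of the new branch: with no `_defer_until`/`_defer_by`, `score` is `None` and the job is published straight to the stream via `publish_job_lua`; with a deferral it keeps the old sorted-set path. A hedged usage sketch (`my_task` is a placeholder function name):

```python
# Sketch, assuming the public enqueue API is unchanged.
from arq.connections import RedisSettings, create_pool


async def demo() -> None:
    redis = await create_pool(RedisSettings())

    # No defer arguments -> score is None -> XADD to 'arq:queue:stream'
    # via publish_job_lua, plus an 'arq:message-id:<job_id>' marker key.
    await redis.enqueue_job('my_task')

    # Deferred -> score set -> ZADD to the 'arq:queue' sorted set as before.
    await redis.enqueue_job('my_task', _defer_by=10)
```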

async def get_queue_size(self, queue_name: str | None = None, include_delayed_tasks: bool = True) -> int:
if queue_name is None:
queue_name = self.default_queue_name

async with self.pipeline(transaction=True) as pipe:
pipe.xlen(queue_name + stream_key_suffix)
pipe.zcount(queue_name, '-inf', '+inf')
stream_size, delayed_queue_size = await pipe.execute()

if not include_delayed_tasks:
return stream_size

return stream_size + delayed_queue_size
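`get_queue_size` now counts both structures; the flag lets callers ignore jobs that aren't due yet:

```python
total = await redis.get_queue_size()                             # stream + delayed
ready = await redis.get_queue_size(include_delayed_tasks=False)  # stream only
```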

async def _get_job_result(self, key: bytes) -> JobResult:
job_id = key[len(result_key_prefix) :].decode()
Expand Down Expand Up @@ -213,7 +264,16 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef]
"""
if queue_name is None:
queue_name = self.default_queue_name
jobs = await self.zrange(queue_name, withscores=True, start=0, end=-1)

async with self.pipeline(transaction=True) as pipe:
pipe.zrange(queue_name, withscores=True, start=0, end=-1)
pipe.xrange(queue_name + stream_key_suffix, '-', '+')
delayed_jobs, stream_jobs = await pipe.execute()

jobs = [
*delayed_jobs,
*[(j[b'job_id'], int(j[b'score'])) for _, j in stream_jobs],
]
return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs])
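The worker-side consumer isn't part of this excerpt, but `default_consumer_group = 'arq:workers'` in `arq/constants.py` suggests consumption via a consumer group. A minimal sketch of what reading the stream could look like (names other than the constants are illustrative):

```python
import asyncio

from redis.asyncio import Redis
from redis.exceptions import ResponseError

STREAM_KEY = 'arq:queue:stream'  # default_queue_name + stream_key_suffix
GROUP = 'arq:workers'            # default_consumer_group


async def consume(redis: Redis, consumer_name: str) -> None:
    try:
        # MKSTREAM so the group can be created before any job is published.
        await redis.xgroup_create(STREAM_KEY, GROUP, id='0', mkstream=True)
    except ResponseError:
        pass  # group already exists

    while True:
        # Block up to 5s for entries not yet delivered to this group ('>').
        resp = await redis.xreadgroup(GROUP, consumer_name, {STREAM_KEY: '>'}, count=10, block=5000)
        for _stream, messages in resp:
            for message_id, fields in messages:
                job_id = fields[b'job_id'].decode()
                # ... fetch 'arq:job:<job_id>', run it, store the result ...
                await redis.xack(STREAM_KEY, GROUP, message_id)
```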


3 changes: 3 additions & 0 deletions arq/constants.py
@@ -1,9 +1,12 @@
default_queue_name = 'arq:queue'
job_key_prefix = 'arq:job:'
in_progress_key_prefix = 'arq:in-progress:'
job_message_id_prefix = 'arq:message-id:'
result_key_prefix = 'arq:result:'
retry_key_prefix = 'arq:retry:'
abort_jobs_ss = 'arq:abort'
stream_key_suffix = ':stream'
default_consumer_group = 'arq:workers'
# age of items in the abort_key sorted set after which they're deleted
abort_job_max_age = 60
health_check_key_suffix = ':health-check'
45 changes: 39 additions & 6 deletions arq/jobs.py
@@ -9,7 +9,16 @@

from redis.asyncio import Redis

from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix
from .constants import (
abort_jobs_ss,
default_queue_name,
in_progress_key_prefix,
job_key_prefix,
job_message_id_prefix,
result_key_prefix,
stream_key_suffix,
)
from .lua import get_job_from_stream_lua
from .utils import ms_to_datetime, poll, timestamp_ms

logger = logging.getLogger('arq.jobs')
@@ -63,6 +72,10 @@ class JobResult(JobDef):
queue_name: str


def _list_to_dict(input_list: list[Any]) -> dict[Any, Any]:
return dict(zip(input_list[::2], input_list[1::2], strict=True))
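`_list_to_dict` pairs up the flat `[field, value, ...]` reply that the Lua `XRANGE` call returns, e.g.:

```python
payload = [b'job_id', b'xyz', b'score', b'1700000000000']
assert _list_to_dict(payload) == {b'job_id': b'xyz', b'score': b'1700000000000'}
```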


class Job:
"""
Holds data and a reference to a job.
@@ -105,7 +118,8 @@ async def result(
async with self._redis.pipeline(transaction=True) as tr:
tr.get(result_key_prefix + self.job_id)
tr.zscore(self._queue_name, self.job_id)
v, s = await tr.execute()
tr.get(job_message_id_prefix + self.job_id)
v, s, m = await tr.execute()

if v:
info = deserialize_result(v, deserializer=self._deserializer)
@@ -115,7 +129,7 @@
raise info.result
else:
raise SerializationError(info.result)
elif s is None:
elif s is None and m is None:
raise ResultNotFound(
'Not waiting for job result because the job is not in queue. '
'Is the worker function configured to keep result?'
@@ -134,8 +148,24 @@ async def info(self) -> Optional[JobDef]:
if v:
info = deserialize_job(v, deserializer=self._deserializer)
if info:
s = await self._redis.zscore(self._queue_name, self.job_id)
info.score = None if s is None else int(s)
async with self._redis.pipeline(transaction=True) as tr:
tr.zscore(self._queue_name, self.job_id)
tr.eval(
get_job_from_stream_lua,
2,
self._queue_name + stream_key_suffix,
job_message_id_prefix + self.job_id,
)
delayed_score, job_info = await tr.execute()

if delayed_score:
info.score = int(delayed_score)
elif job_info:
_, job_info_payload = job_info
info.score = int(_list_to_dict(job_info_payload)[b'score'])
else:
info.score = None

return info

async def result_info(self) -> Optional[JobResult]:
@@ -157,12 +187,15 @@ async def status(self) -> JobStatus:
tr.exists(result_key_prefix + self.job_id)
tr.exists(in_progress_key_prefix + self.job_id)
tr.zscore(self._queue_name, self.job_id)
is_complete, is_in_progress, score = await tr.execute()
tr.exists(job_message_id_prefix + self.job_id)
is_complete, is_in_progress, score, queued = await tr.execute()

if is_complete:
return JobStatus.complete
elif is_in_progress:
return JobStatus.in_progress
elif queued:
return JobStatus.queued
elif score:
return JobStatus.deferred if score > timestamp_ms() else JobStatus.queued
else:
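Since stream jobs have no zset score, the new `exists(job_message_id_prefix + job_id)` check is what keeps them reporting `queued`. A sketch of the expected progression (`my_task` is a placeholder):

```python
job = await redis.enqueue_job('my_task')               # straight to the stream
assert await job.status() == JobStatus.queued          # message-id key exists

deferred = await redis.enqueue_job('my_task', _defer_by=60)
assert await deferred.status() == JobStatus.deferred   # zset score in the future
```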
48 changes: 48 additions & 0 deletions arq/lua.py
@@ -0,0 +1,48 @@
publish_delayed_job_lua = """
local delayed_queue_key = KEYS[1]
local stream_key = KEYS[2]
local job_message_id_key = KEYS[3]

local job_id = ARGV[1]
local job_message_id_expire_ms = ARGV[2]

local score = redis.call('zscore', delayed_queue_key, job_id)
if score == nil or score == false then
return 0
end

local message_id = redis.call('xadd', stream_key, '*', 'job_id', job_id, 'score', score)
redis.call('set', job_message_id_key, message_id, 'px', job_message_id_expire_ms)
redis.call('zrem', delayed_queue_key, job_id)
return 1
"""

publish_job_lua = """
local stream_key = KEYS[1]
local job_message_id_key = KEYS[2]

local job_id = ARGV[1]
local score = ARGV[2]
local job_message_id_expire_ms = ARGV[3]

local message_id = redis.call('xadd', stream_key, '*', 'job_id', job_id, 'score', score)
redis.call('set', job_message_id_key, message_id, 'px', job_message_id_expire_ms)
return message_id
"""

get_job_from_stream_lua = """
local stream_key = KEYS[1]
local job_message_id_key = KEYS[2]

local message_id = redis.call('get', job_message_id_key)
if message_id == false then
return nil
end

local job = redis.call('xrange', stream_key, message_id, message_id)
if job == nil then
return nil
end

return job[1]
"""