From ab2dda2011ab27007650c4918d3704f3bf7ac13d Mon Sep 17 00:00:00 2001 From: Rishabh Mittal Date: Mon, 30 Oct 2023 23:39:26 +0530 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20Adding=20a=20job=20counter=20to?= =?UTF-8?q?=20address=20Semaphore=20issues=20(#408)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔨 Adding a job counter to address Semaphore issues * 🧪 Test function for semaphore blocker --- arq/worker.py | 25 +++++++++++++++++++++---- tests/test_worker.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/arq/worker.py b/arq/worker.py index 81afd5b7..398409b5 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -236,7 +236,11 @@ def __init__( self.on_job_start = on_job_start self.on_job_end = on_job_end self.after_job_end = after_job_end - self.sem = asyncio.BoundedSemaphore(max_jobs) + + self.max_jobs = max_jobs + self.sem = asyncio.BoundedSemaphore(max_jobs + 1) + self.job_counter: int = 0 + self.job_timeout_s = to_seconds(job_timeout) self.keep_result_s = to_seconds(keep_result) self.keep_result_forever = keep_result_forever @@ -374,13 +378,13 @@ async def _poll_iteration(self) -> None: return count = min(burst_jobs_remaining, count) if self.allow_pick_jobs: - async with self.sem: # don't bother with zrangebyscore until we have "space" to run the jobs + if self.job_counter < self.max_jobs: now = timestamp_ms() job_ids = await self.pool.zrangebyscore( self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now ) - await self.start_jobs(job_ids) + await self.start_jobs(job_ids) if self.allow_abort_jobs: await self._cancel_aborted_jobs() @@ -419,12 +423,23 @@ async def _cancel_aborted_jobs(self) -> None: self.aborting_tasks.update(aborted) await self.pool.zrem(abort_jobs_ss, *aborted) + def _release_sem_dec_counter_on_complete(self) -> None: + self.job_counter = self.job_counter - 1 + self.sem.release() + async def start_jobs(self, 
job_ids: List[bytes]) -> None: """ For each job id, get the job definition, check it's not running and start it in a task """ for job_id_b in job_ids: await self.sem.acquire() + + if self.job_counter >= self.max_jobs: + self.sem.release() + return None + + self.job_counter = self.job_counter + 1 + job_id = job_id_b.decode() in_progress_key = in_progress_key_prefix + job_id async with self.pool.pipeline(transaction=True) as pipe: @@ -433,6 +448,7 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: score = await pipe.zscore(self.queue_name, job_id) if ongoing_exists or not score: # job already started elsewhere, or already finished and removed from queue + self.job_counter = self.job_counter - 1 self.sem.release() logger.debug('job %s already running elsewhere', job_id) continue @@ -445,11 +461,12 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: await pipe.execute() except (ResponseError, WatchError): # job already started elsewhere since we got 'existing' + self.job_counter = self.job_counter - 1 self.sem.release() logger.debug('multi-exec error, job %s already started elsewhere', job_id) else: t = self.loop.create_task(self.run_job(job_id, int(score))) - t.add_done_callback(lambda _: self.sem.release()) + t.add_done_callback(lambda _: self._release_sem_dec_counter_on_complete()) self.tasks[job_id] = t async def run_job(self, job_id: str, score: int) -> None: # noqa: C901 diff --git a/tests/test_worker.py b/tests/test_worker.py index aa56085b..23dd91d2 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -984,6 +984,36 @@ async def test(ctx): assert result['called'] == 4 +async def test_job_cancel_on_max_jobs(arq_redis: ArqRedis, worker, caplog): + async def longfunc(ctx): + await asyncio.sleep(3600) + + async def wait_and_abort(job, delay=0.1): + await asyncio.sleep(delay) + assert await job.abort() is True + + caplog.set_level(logging.INFO) + await arq_redis.zadd(abort_jobs_ss, {b'foobar': int(1e9)}) + job = await 
arq_redis.enqueue_job('longfunc', _job_id='testing') + + worker: Worker = worker( + functions=[func(longfunc, name='longfunc')], allow_abort_jobs=True, poll_delay=0.1, max_jobs=1 + ) + assert worker.jobs_complete == 0 + assert worker.jobs_failed == 0 + assert worker.jobs_retried == 0 + await asyncio.gather(wait_and_abort(job), worker.main()) + await worker.main() + assert worker.jobs_complete == 0 + assert worker.jobs_failed == 1 + assert worker.jobs_retried == 0 + log = re.sub(r'\d+.\d\ds', 'X.XXs', '\n'.join(r.message for r in caplog.records)) + assert 'X.XXs โ†’ testing:longfunc()\n X.XXs โŠ˜ testing:longfunc aborted' in log + assert worker.aborting_tasks == set() + assert worker.tasks == {} + assert worker.job_tasks == {} + + async def test_worker_timezone_defaults_to_system_timezone(worker): worker = worker(functions=[func(foobar)]) assert worker.timezone is not None