Skip to content

Commit

Permalink
Create and provide asyncio.Lock directly in scheduler.execute
Browse files Browse the repository at this point in the history
To make sure that the forward_model_ok_lock is assigned to the correct running loop, this creates and provide the lock to the job directly.
  • Loading branch information
xjules committed Aug 30, 2024
1 parent 09cdd6b commit f4ffb68
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 14 deletions.
9 changes: 7 additions & 2 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,12 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
timeout_task.cancel()
sem.release()

async def run(self, sem: asyncio.BoundedSemaphore, max_submit: int = 1) -> None:
async def run(
self,
sem: asyncio.BoundedSemaphore,
forward_model_ok_lock: asyncio.Lock,
max_submit: int = 1,
) -> None:
self._requested_max_submit = max_submit
for attempt in range(max_submit):
await self._submit_and_run_once(sem)
Expand All @@ -147,7 +152,7 @@ async def run(self, sem: asyncio.BoundedSemaphore, max_submit: int = 1) -> None:
if self.returncode.result() == 0:
if self._scheduler._manifest_queue is not None:
await self._verify_checksum()
async with self._scheduler._forward_model_ok_lock:
async with forward_model_ok_lock:
await self._handle_finished_forward_model()
break

Expand Down
10 changes: 5 additions & 5 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,6 @@ def __init__(
self._completed_jobs_num: int = 0
self.completed_jobs: asyncio.Queue[int] = asyncio.Queue()

# this lock is to assure that no more than 1 task
# does internalization at a time
self._forward_model_ok_lock: asyncio.Lock = asyncio.Lock()

self._cancelled = False
if max_submit < 0:
raise ValueError(
Expand Down Expand Up @@ -279,9 +275,13 @@ async def execute(
scheduling_tasks.append(asyncio.create_task(self._update_avg_job_runtime()))

sem = asyncio.BoundedSemaphore(self._max_running or len(self._jobs))
# this lock is to assure that no more than 1 task
# does internalization at a time
forward_model_ok_lock = asyncio.Lock()
for iens, job in self._jobs.items():
self._job_tasks[iens] = asyncio.create_task(
job.run(sem, self._max_submit), name=f"job-{iens}_task"
job.run(sem, forward_model_ok_lock, self._max_submit),
name=f"job-{iens}_task",
)

try:
Expand Down
26 changes: 19 additions & 7 deletions tests/unit_tests/scheduler/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ async def load_result(_):
job.started.set()

job_run_task = asyncio.create_task(
job.run(asyncio.Semaphore(), max_submit=max_submit)
job.run(asyncio.Semaphore(), asyncio.Lock(), max_submit=max_submit)
)

for attempt in range(max_submit):
Expand Down Expand Up @@ -176,7 +176,9 @@ async def test_num_cpu_is_propagated_to_driver(realization: Realization):
realization.num_cpu = 8
scheduler = create_scheduler()
job = Job(scheduler, realization)
job_run_task = asyncio.create_task(job.run(asyncio.Semaphore(), max_submit=1))
job_run_task = asyncio.create_task(
job.run(asyncio.Semaphore(), asyncio.Lock(), max_submit=1)
)
job.started.set()
job.returncode.set_result(0)
await job_run_task
Expand All @@ -197,7 +199,9 @@ async def test_realization_memory_is_propagated_to_driver(realization: Realizati
realization.realization_memory = 8 * 1024**2
scheduler = create_scheduler()
job = Job(scheduler, realization)
job_run_task = asyncio.create_task(job.run(asyncio.Semaphore(), max_submit=1))
job_run_task = asyncio.create_task(
job.run(asyncio.Semaphore(), asyncio.Lock(), max_submit=1)
)
job.started.set()
job.returncode.set_result(0)
await job_run_task
Expand Down Expand Up @@ -233,7 +237,9 @@ async def test_when_waiting_for_disk_sync_times_out_an_error_is_logged(
job.started.set()

with captured_logs(log_msgs, logging.ERROR):
job_run_task = asyncio.create_task(job.run(asyncio.Semaphore(), max_submit=1))
job_run_task = asyncio.create_task(
job.run(asyncio.Semaphore(), asyncio.Lock(), max_submit=1)
)
job.started.set()
job.returncode.set_result(0)
await job_run_task
Expand Down Expand Up @@ -262,7 +268,9 @@ async def test_when_files_in_manifest_are_not_created_an_error_is_logged(
job.started.set()

with captured_logs(log_msgs, logging.ERROR):
job_run_task = asyncio.create_task(job.run(asyncio.Semaphore(), max_submit=1))
job_run_task = asyncio.create_task(
job.run(asyncio.Semaphore(), asyncio.Lock(), max_submit=1)
)
job.started.set()
job.returncode.set_result(0)
await job_run_task
Expand Down Expand Up @@ -294,7 +302,9 @@ async def test_when_checksums_do_not_match_a_warning_is_logged(
job.started.set()

with captured_logs(log_msgs, logging.WARNING):
job_run_task = asyncio.create_task(job.run(asyncio.Semaphore(), max_submit=1))
job_run_task = asyncio.create_task(
job.run(asyncio.Semaphore(), asyncio.Lock(), max_submit=1)
)
job.started.set()
job.returncode.set_result(0)
await job_run_task
Expand All @@ -320,7 +330,9 @@ async def test_when_no_checksum_info_is_received_a_warning_is_logged(
mocker.patch("asyncio.sleep", return_value=None)

with captured_logs(log_msgs, logging.WARNING):
job_run_task = asyncio.create_task(job.run(asyncio.Semaphore(), max_submit=1))
job_run_task = asyncio.create_task(
job.run(asyncio.Semaphore(), asyncio.Lock(), max_submit=1)
)
job.started.set()
job.returncode.set_result(0)
await job_run_task
Expand Down

0 comments on commit f4ffb68

Please sign in to comment.