Skip to content

Commit

Permalink
Fix max_in_flight functionality. (#324)
Browse files Browse the repository at this point in the history
Currently, it's effectively always 1.
  • Loading branch information
Gal Topper authored Dec 9, 2021
1 parent 06855c3 commit 6272997
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 7 deletions.
23 changes: 19 additions & 4 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import aiohttp

from .dtypes import _termination_obj, Event, FlowError, V3ioError
from .queue import AsyncQueue
from .table import Table
from .utils import _split_path, get_in, update_in, stringify_key

Expand Down Expand Up @@ -620,15 +621,27 @@ def _init(self):
self._lazy_init_complete = False

async def _worker(self):
async def handle_job(job):
if job is _termination_obj:
return
event = job[0]
completed = await job[1]
await self._handle_completed(event, completed)

event = None
try:
while True:
job = await self._q.get()
# If we don't handle the event before we remove it from the queue, the effective max_in_flight will
# be 1 higher than requested. Hence, we peek.
job = await self._q.peek()
if job is _termination_obj:
await self._q.get()
break
event = job[0]
completed = await job[1]
await self._handle_completed(event, completed)
await self._q.get()

except BaseException as ex:
if event and event._awaitable_result:
none_or_coroutine = event._awaitable_result._set_error(ex)
Expand Down Expand Up @@ -676,7 +689,7 @@ async def _do(self, event):
self._lazy_init_complete = True

if not self._q and self._queue_size > 0:
self._q = asyncio.queues.Queue(self._queue_size)
self._q = AsyncQueue(self._queue_size)
self._worker_awaitable = asyncio.get_running_loop().create_task(self._worker())

if self._queue_size > 0 and self._worker_awaitable.done():
Expand All @@ -691,9 +704,11 @@ async def _do(self, event):
else:
coroutine = self._process_event_with_retries(event)
if self._queue_size == 0:
await coroutine
completed = await coroutine
await self._handle_completed(event, completed)
else:
await self._q.put((event, coroutine))
task = asyncio.get_running_loop().create_task(coroutine)
await self._q.put((event, task))
if self._worker_awaitable.done():
await self._worker_awaitable

Expand Down
39 changes: 39 additions & 0 deletions storey/queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import asyncio


class AsyncQueue(asyncio.Queue):
    """asyncio.Queue with non-consuming ``peek`` and ``peek_nowait`` methods added.

    Peeking returns the item at the head of the queue while leaving it in
    place, so the slot it occupies still counts towards ``maxsize`` until a
    subsequent ``get``/``get_nowait`` removes it.
    """

    async def peek(self):
        """Return the item at the head of the queue without removing it.

        If the queue is empty, wait until an item is available.
        """
        while self.empty():
            # Use the running loop instead of self._loop: since Python 3.10
            # asyncio.Queue binds its loop lazily, so self._loop may still be
            # unset the first time we get here.
            getter = asyncio.get_running_loop().create_future()
            self._getters.append(getter)
            try:
                await getter
            except:  # noqa: E722 - must also intercept CancelledError, then re-raise.
                getter.cancel()  # Just in case getter is not done yet.
                try:
                    # Clean self._getters from canceled getters.
                    self._getters.remove(getter)
                except ValueError:
                    # The getter could be removed from self._getters by a
                    # previous put_nowait call.
                    pass
                if not self.empty() and not getter.cancelled():
                    # We were woken up by put_nowait(), but can't take
                    # the call. Wake up the next in line.
                    self._wakeup_next(self._getters)
                raise
        return self.peek_nowait()

    def peek_nowait(self):
        """Return the item at the head of the queue without removing it.

        Raises:
            asyncio.QueueEmpty: if no item is immediately available.
        """
        if self.empty():
            raise asyncio.QueueEmpty
        # Peeking frees no slot, so waiting putters are deliberately NOT woken
        # here. Waking one while the queue is still full would only make it
        # re-enqueue itself behind later putters, breaking put() ordering.
        return self._peek()

    def _peek(self):
        """Return the head of the underlying container without removing it."""
        return self._queue[0]
9 changes: 6 additions & 3 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2921,13 +2921,14 @@ async def _handle_completed(self, event, response):


# ML-1506
@pytest.mark.parametrize('max_in_flight', [1, 2])
@pytest.mark.parametrize('max_in_flight', [1, 2, 4])
def test_concurrent_execution_max_in_flight(max_in_flight):
class _TestConcurrentExecution(_ConcurrentJobExecution):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._ongoing_processing = 0
self.lazy_init_called = 0
self.handle_completed_called = 0

async def _lazy_init(self):
self.lazy_init_called += 1
Expand All @@ -2939,20 +2940,22 @@ async def _process_event(self, event):
self._ongoing_processing -= 1

async def _handle_completed(self, event, response):
pass
self.handle_completed_called += 1

concurrent_step = _TestConcurrentExecution(max_in_flight=max_in_flight)
controller = build_flow([
SyncEmitSource(),
concurrent_step,
]).run()

for i in range(max_in_flight + 1):
num_events = max_in_flight + 1
for i in range(num_events):
controller.emit(i)
controller.terminate()
controller.await_termination()

assert concurrent_step.lazy_init_called == 1
assert concurrent_step.handle_completed_called == num_events


class MockLogger:
Expand Down

0 comments on commit 6272997

Please sign in to comment.