Skip to content

Commit

Permalink
Fix error handling in concurrent execution. (#328)
Browse files Browse the repository at this point in the history
Do not fail the flow when push_error is available.
  • Loading branch information
Gal Topper authored Dec 16, 2021
1 parent 85d1693 commit 2d73c15
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 26 deletions.
60 changes: 35 additions & 25 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,35 +621,45 @@ def _init(self):
self._lazy_init_complete = False

async def _worker(self):
async def handle_job(job):
if job is _termination_obj:
return
event = job[0]
completed = await job[1]
await self._handle_completed(event, completed)

event = None
try:
while True:
# If we don't handle the event before we remove it from the queue, the effective max_in_flight will
# be 1 higher than requested. Hence, we peek.
job = await self._q.peek()
if job is _termination_obj:
try:
# If we don't handle the event before we remove it from the queue, the effective max_in_flight will
# be 1 higher than requested. Hence, we peek.
job = await self._q.peek()
if job is _termination_obj:
await self._q.get()
break
event = job[0]
completed = await job[1]
await self._handle_completed(event, completed)
await self._q.get()
break
event = job[0]
completed = await job[1]
await self._handle_completed(event, completed)
await self._q.get()

except BaseException as ex:
if event and event._awaitable_result:
none_or_coroutine = event._awaitable_result._set_error(ex)
if none_or_coroutine:
await none_or_coroutine
if not self._q.empty():
await self._q.get()
raise ex
except BaseException as ex:
await self._q.get()
ex._raised_by_storey_step = self
recovery_step = self._get_recovery_step(ex)
try:
if recovery_step is not None:
event.origin_state = self.name
event.error = ex
return await recovery_step._do(event)
else:
if event._awaitable_result:
none_or_coroutine = event._awaitable_result._set_error(ex)
if none_or_coroutine:
await none_or_coroutine
if self.context and hasattr(self.context, 'push_error'):
message = traceback.format_exc()
if self.logger:
self.logger.error(f'Pushing error to error stream: {ex}\n{message}')
self.context.push_error(event, f"{ex}\n{message}", source=self.name)
else:
raise ex
except BaseException:
if not self._q.empty():
await self._q.get()
raise
finally:
await self._cleanup()

Expand Down
62 changes: 61 additions & 1 deletion tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
Event, Batch, Table, CSVTarget, DataframeSource, MapClass, JoinWithTable, ReduceToDataFrame, ToDataFrame, \
ParquetTarget, QueryByKey, \
TSDBTarget, Extend, SendToHttp, HttpRequest, NoSqlTarget, NoopDriver, Driver, Recover, V3ioDriver, ParquetSource
from storey.flow import _ConcurrentJobExecution
from storey.flow import _ConcurrentJobExecution, Context


class ATestException(Exception):
Expand Down Expand Up @@ -2958,6 +2958,66 @@ async def _handle_completed(self, event, response):
assert concurrent_step.handle_completed_called == num_events


def test_concurrent_execution_max_in_flight_error():
    """Check error propagation from a concurrent step when no push_error is configured.

    The step's ``_process_event`` always raises ``ATestException``. The test
    asserts that the exception surfaces both when awaiting the emitted event's
    result and when awaiting flow termination.
    """
    class _TestConcurrentExecution(_ConcurrentJobExecution):
        async def _process_event(self, event):
            # Every event fails.
            raise ATestException()

        async def _handle_completed(self, event, response):
            pass

    step = _TestConcurrentExecution(max_in_flight=2)
    flow = build_flow([
        SyncEmitSource(),
        step,
        Complete(),
    ])
    controller = flow.run()

    # The failure must be reported on the event's own awaitable result...
    pending = controller.emit(0)
    with pytest.raises(ATestException):
        pending.await_result()

    # ...and again when the flow is shut down.
    controller.terminate()
    with pytest.raises(ATestException):
        controller.await_termination()


def test_concurrent_execution_max_in_flight_push_error():
    """Check that a concurrent step keeps working after an error when push_error exists.

    The step raises ``ATestException`` for the first event only. The context
    passed to the step defines ``push_error``. The test asserts that the first
    event's result raises, that four subsequent events complete without
    raising, and that terminating the flow does not raise.
    """
    class _TestConcurrentExecution(_ConcurrentJobExecution):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # Armed for exactly one failure.
            self._should_raise = True

        async def _process_event(self, event):
            if self._should_raise:
                self._should_raise = False
                raise ATestException()

        async def _handle_completed(self, event, response):
            await self._do_downstream(event)

    class ContextWithPushError(Context):
        def push_error(self, event, message, source):
            # Deliberate no-op; only the method's presence matters here.
            pass

    step = _TestConcurrentExecution(max_in_flight=2, context=ContextWithPushError())
    flow = build_flow([
        SyncEmitSource(),
        step,
        Complete(),
    ])
    controller = flow.run()

    # First event: the error is still reported on the event's awaitable result.
    first = controller.emit(0)
    with pytest.raises(ATestException):
        first.await_result()

    # Later events must go through normally.
    for i in range(1, 5):
        controller.emit(i).await_result()

    # Shutdown must not raise.
    controller.terminate()
    controller.await_termination()


class MockLogger:
def __init__(self):
self.logs = []
Expand Down

0 comments on commit 2d73c15

Please sign in to comment.