diff --git a/storey/flow.py b/storey/flow.py
index 2c4b559a..e1ded8ed 100644
--- a/storey/flow.py
+++ b/storey/flow.py
@@ -621,35 +621,59 @@ def _init(self):
         self._lazy_init_complete = False
 
     async def _worker(self):
+        """Serve jobs from the queue in order until the termination object is seen.
+
+        A failed job is removed from the queue and its error is handed to the matching recovery step. Without
+        one, the error is set on the event's awaitable result and pushed to the context's push_error, if any.
+        If the context cannot take the error either, it is re-raised and the worker stops.
+        """
-        async def handle_job(job):
-            if job is _termination_obj:
-                return
-            event = job[0]
-            completed = await job[1]
-            await self._handle_completed(event, completed)
 
-        event = None
         try:
             while True:
-                # If we don't handle the event before we remove it from the queue, the effective max_in_flight will
-                # be 1 higher than requested. Hence, we peek.
-                job = await self._q.peek()
-                if job is _termination_obj:
+                # Reset on every iteration so the error handler can tell whether a job was actually picked up.
+                event = None
+                try:
+                    # If we don't handle the event before we remove it from the queue, the effective max_in_flight will
+                    # be 1 higher than requested. Hence, we peek.
+                    job = await self._q.peek()
+                    if job is _termination_obj:
+                        await self._q.get()
+                        break
+                    event = job[0]
+                    completed = await job[1]
+                    await self._handle_completed(event, completed)
                     await self._q.get()
-                    break
-                event = job[0]
-                completed = await job[1]
-                await self._handle_completed(event, completed)
-                await self._q.get()
-
-        except BaseException as ex:
-            if event and event._awaitable_result:
-                none_or_coroutine = event._awaitable_result._set_error(ex)
-                if none_or_coroutine:
-                    await none_or_coroutine
-            if not self._q.empty():
-                await self._q.get()
-            raise ex
+                except BaseException as ex:
+                    if event is None:
+                        # No job was picked up (e.g. cancelled while waiting on an empty queue): there is nothing
+                        # to discard and no event to report on, and awaiting get() here could block forever.
+                        raise
+                    await self._q.get()
+                    ex._raised_by_storey_step = self
+                    recovery_step = self._get_recovery_step(ex)
+                    try:
+                        if recovery_step is not None:
+                            event.origin_state = self.name
+                            event.error = ex
+                            # Keep serving the queue after recovery; returning here would stop the worker and
+                            # strand every event queued behind the failed one.
+                            await recovery_step._do(event)
+                        else:
+                            if event._awaitable_result:
+                                none_or_coroutine = event._awaitable_result._set_error(ex)
+                                if none_or_coroutine:
+                                    await none_or_coroutine
+                            if self.context and hasattr(self.context, 'push_error'):
+                                message = traceback.format_exc()
+                                if self.logger:
+                                    self.logger.error(f'Pushing error to error stream: {ex}\n{message}')
+                                self.context.push_error(event, f"{ex}\n{message}", source=self.name)
+                            else:
+                                raise ex
+                    except BaseException:
+                        if not self._q.empty():
+                            await self._q.get()
+                        raise
         finally:
             await self._cleanup()
 
diff --git a/tests/test_flow.py b/tests/test_flow.py
index 0e7f2924..a9033046 100644
--- a/tests/test_flow.py
+++ b/tests/test_flow.py
@@ -21,7 +21,7 @@
     Event, Batch, Table, CSVTarget, DataframeSource, MapClass, JoinWithTable, ReduceToDataFrame, ToDataFrame, \
     ParquetTarget, QueryByKey, \
     TSDBTarget, Extend, SendToHttp, HttpRequest, NoSqlTarget, NoopDriver, Driver, Recover, V3ioDriver, ParquetSource
-from storey.flow import _ConcurrentJobExecution
+from storey.flow import _ConcurrentJobExecution, Context
 
 
 class ATestException(Exception):
@@ -2958,6 +2958,66 @@ async def _handle_completed(self, event, response):
     assert concurrent_step.handle_completed_called == num_events
 
 
+def test_concurrent_execution_max_in_flight_error():
+    class _TestConcurrentExecution(_ConcurrentJobExecution):
+        async def _process_event(self, event):
+            raise ATestException()
+
+        async def _handle_completed(self, event, response):
+            pass
+
+    concurrent_step = _TestConcurrentExecution(max_in_flight=2)
+    controller = build_flow([
+        SyncEmitSource(),
+        concurrent_step,
+        Complete()
+    ]).run()
+
+    awaitable_result = controller.emit(0)
+    with pytest.raises(ATestException):
+        awaitable_result.await_result()
+    controller.terminate()
+    with pytest.raises(ATestException):
+        controller.await_termination()
+
+
+def test_concurrent_execution_max_in_flight_push_error():
+    class _TestConcurrentExecution(_ConcurrentJobExecution):
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+            self._should_raise = True
+
+        async def _process_event(self, event):
+            if self._should_raise:
+                self._should_raise = False
+                raise ATestException()
+
+        async def _handle_completed(self, event, response):
+            await self._do_downstream(event)
+
+    class ContextWithPushError(Context):
+        def push_error(self, event, message, source):
+            pass
+
+    context = ContextWithPushError()
+
+    concurrent_step = _TestConcurrentExecution(max_in_flight=2, context=context)
+    controller = build_flow([
+        SyncEmitSource(),
+        concurrent_step,
+        Complete()
+    ]).run()
+
+    awaitable_result = controller.emit(0)
+    with pytest.raises(ATestException):
+        awaitable_result.await_result()
+    for i in range(1, 5):
+        awaitable_result = controller.emit(i)
+        awaitable_result.await_result()
+    controller.terminate()
+    controller.await_termination()
+
+
 class MockLogger:
     def __init__(self):
         self.logs = []