diff --git a/storey/flow.py b/storey/flow.py index 29db23f2..bb232f52 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -216,12 +216,19 @@ def _user_fn_output_to_event(self, event, fn_result): mapped_event.body = fn_result return mapped_event - def _check_stage_in_flow(self, type_to_check): - for step in self._outlets: - if isinstance(step, type_to_check): + def _check_step_in_flow(self, type_to_check): + if isinstance(self, type_to_check): + return True + for outlet in self._outlets: + if outlet._check_step_in_flow(type_to_check): return True - if step._check_stage_in_flow(type_to_check): + if isinstance(self._recovery_step, Flow): + if self._recovery_step._check_step_in_flow(type_to_check): return True + elif isinstance(self._recovery_step, dict): + for step in self._recovery_step.values(): + if step._check_step_in_flow(type_to_check): + return True return False diff --git a/storey/sources.py b/storey/sources.py index 2b06df90..39333d5b 100644 --- a/storey/sources.py +++ b/storey/sources.py @@ -221,7 +221,7 @@ def raise_error_or_return_termination_result(): self._raise_on_error(self._termination_q.get()) return self._termination_future.result() - has_complete = self._check_stage_in_flow(Complete) + has_complete = self._check_step_in_flow(Complete) return FlowController(self._emit, raise_error_or_return_termination_result, has_complete, self._key_field, self._time_field) @@ -365,7 +365,7 @@ async def run(self): """Starts the flow""" self._closeables = super().run() loop_task = asyncio.get_running_loop().create_task(self._run_loop()) - has_complete = self._check_stage_in_flow(Complete) + has_complete = self._check_step_in_flow(Complete) return AsyncFlowController(self._emit, loop_task, has_complete, self._key_field, self._time_field) diff --git a/tests/test_flow.py b/tests/test_flow.py index 600ae411..87fa41eb 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -2129,3 +2129,27 @@ def test_writer_downstream(tmpdir): controller.terminate() result = controller.await_termination() assert result == 45 + + +def test_complete_in_error_flow(): + reduce = build_flow([ + Complete(), + Reduce(0, lambda acc, x: acc + x) + ]) + controller = build_flow([ + Source(), + Map(lambda x: x + 1), + Map(RaiseEx(5).raise_ex, recovery_step=reduce), + Map(lambda x: x * 100), + reduce + ]).run() + + for i in range(10): + awaitable_result = controller.emit(i) + if i == 4: + assert awaitable_result.await_result() == i + 1 + else: + assert awaitable_result.await_result() == (i + 1) * 100 + controller.terminate() + termination_result = controller.await_termination() + assert termination_result == 5005