Skip to content

Commit

Permalink
flow/check_stage_in_flow need to check recovery step as well (#169)
Browse files Browse the repository at this point in the history
* flow/check_stage_in_flow need to check recovery step as well

* bug fix

* bug fix

* adding a test

* minor fix

* minor fix

* test improvement
  • Loading branch information
katyakats authored Feb 15, 2021
1 parent bb00003 commit 92f404d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 6 deletions.
15 changes: 11 additions & 4 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,19 @@ def _user_fn_output_to_event(self, event, fn_result):
mapped_event.body = fn_result
return mapped_event

def _check_stage_in_flow(self, type_to_check):
for step in self._outlets:
if isinstance(step, type_to_check):
def _check_step_in_flow(self, type_to_check):
if isinstance(self, type_to_check):
return True
for outlet in self._outlets:
if outlet._check_step_in_flow(type_to_check):
return True
if step._check_stage_in_flow(type_to_check):
if isinstance(self._recovery_step, Flow):
if self._recovery_step._check_step_in_flow(type_to_check):
return True
elif isinstance(self._recovery_step, dict):
for step in self._recovery_step.values():
if step._check_step_in_flow(type_to_check):
return True
return False


Expand Down
4 changes: 2 additions & 2 deletions storey/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def raise_error_or_return_termination_result():
self._raise_on_error(self._termination_q.get())
return self._termination_future.result()

has_complete = self._check_stage_in_flow(Complete)
has_complete = self._check_step_in_flow(Complete)

return FlowController(self._emit, raise_error_or_return_termination_result, has_complete, self._key_field, self._time_field)

Expand Down Expand Up @@ -365,7 +365,7 @@ async def run(self):
"""Starts the flow"""
self._closeables = super().run()
loop_task = asyncio.get_running_loop().create_task(self._run_loop())
has_complete = self._check_stage_in_flow(Complete)
has_complete = self._check_step_in_flow(Complete)
return AsyncFlowController(self._emit, loop_task, has_complete, self._key_field, self._time_field)


Expand Down
24 changes: 24 additions & 0 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2129,3 +2129,27 @@ def test_writer_downstream(tmpdir):
controller.terminate()
result = controller.await_termination()
assert result == 45


def test_complete_in_error_flow():
reduce = build_flow([
Complete(),
Reduce(0, lambda acc, x: acc + x)
])
controller = build_flow([
Source(),
Map(lambda x: x + 1),
Map(RaiseEx(5).raise_ex, recovery_step=reduce),
Map(lambda x: x * 100),
reduce
]).run()

for i in range(10):
awaitable_result = controller.emit(i)
if i == 4:
assert awaitable_result.await_result() == i + 1
else:
assert awaitable_result.await_result() == (i + 1) * 100
controller.terminate()
termination_result = controller.await_termination()
assert termination_result == 5005

0 comments on commit 92f404d

Please sign in to comment.