From 4e6645d7de41a5d75290d06c2a8faeceb3d5637c Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Wed, 11 Sep 2024 20:29:28 +0800 Subject: [PATCH 1/5] Rewrite Choice step to make it usable from mlrun [ML-7818](https://iguazio.atlassian.net/browse/ML-7818) --- storey/flow.py | 86 ++++++++++++++++++++-------------------------- storey/sources.py | 1 + tests/test_flow.py | 44 ++++++++++++++++-------- 3 files changed, 69 insertions(+), 62 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index 8b616a17..d8798e9e 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -252,15 +252,16 @@ def _event_string(event): def _should_terminate(self): return self._termination_received == len(self._inlets) - async def _do_downstream(self, event): - if not self._outlets: + async def _do_downstream(self, event, outlets=None): + outlets = self._outlets if outlets is None else outlets + if not outlets: return if event is _termination_obj: # Only propagate the termination object once we received one per inlet - self._outlets[0]._termination_received += 1 - if self._outlets[0]._should_terminate(): - self._termination_result = await self._outlets[0]._do(_termination_obj) - for outlet in self._outlets[1:] + self._get_recovery_steps(): + outlets[0]._termination_received += 1 + if outlets[0]._should_terminate(): + self._termination_result = await outlets[0]._do(_termination_obj) + for outlet in outlets[1:] + self._get_recovery_steps(): outlet._termination_received += 1 if outlet._should_terminate(): self._termination_result = self._termination_result_fn( @@ -269,28 +270,28 @@ async def _do_downstream(self, event): return self._termination_result # If there is more than one outlet, allow concurrent execution. tasks = [] - if len(self._outlets) > 1: + if len(outlets) > 1: awaitable_result = event._awaitable_result event._awaitable_result = None original_events = getattr(event, "_original_events", None) # Temporarily delete self-reference to avoid deepcopy getting stuck in an infinite loop event._original_events = None - for i in range(1, len(self._outlets)): + for i in range(1, len(outlets)): event_copy = copy.deepcopy(event) event_copy._awaitable_result = awaitable_result event_copy._original_events = original_events - tasks.append(asyncio.get_running_loop().create_task(self._outlets[i]._do_and_recover(event_copy))) + tasks.append(asyncio.get_running_loop().create_task(outlets[i]._do_and_recover(event_copy))) # Set self-reference back after deepcopy event._original_events = original_events event._awaitable_result = awaitable_result if self.verbose and self.logger: step_name = self.name event_string = self._event_string(event) - self.logger.debug(f"{step_name} -> {self._outlets[0].name} | {event_string}") - await self._outlets[0]._do_and_recover(event) # Optimization - avoids creating a task for the first outlet. + self.logger.debug(f"{step_name} -> {outlets[0].name} | {event_string}") + await outlets[0]._do_and_recover(event) # Optimization - avoids creating a task for the first outlet. for i, task in enumerate(tasks, start=1): if self.verbose and self.logger: - self.logger.debug(f"{step_name} -> {self._outlets[i].name} | {event_string}") + self.logger.debug(f"{step_name} -> {outlets[i].name} | {event_string}") await task def _get_event_or_body(self, event): @@ -347,46 +348,35 @@ def _get_uuid(self): class Choice(Flow): - """Redirects each input element into at most one of multiple downstreams. - - :param choice_array: a list of (downstream, condition) tuples, where downstream is a step and condition is a - function. The first condition in the list to evaluate as true for an input element causes that element to - be redirected to that downstream step. - :type choice_array: tuple of (Flow, Function (Event=>boolean)) - :param default: a default step for events that did not match any condition in choice_array. If not set, elements - that don't match any condition will be discarded. - :type default: Flow - :param name: Name of this step, as it should appear in logs. Defaults to class name (Choice). - :type name: string - :param full_event: Whether user functions should receive and return Event objects (when True), - or only the payload (when False). Defaults to False. - :type full_event: boolean - """ + """Redirects each input element into any number of predetermined downstream steps.""" - def __init__(self, choice_array, default=None, **kwargs): - Flow.__init__(self, **kwargs) - - self._choice_array = choice_array - for outlet, _ in choice_array: - self.to(outlet) + def _init(self, **kwargs): + super()._init() + self._name_to_outlet = {} + for outlet in self._outlets: + if outlet.name in self._name_to_outlet: + raise ValueError(f"Ambiguous outlet name '{outlet.name}' in Choice step") + self._name_to_outlet[outlet.name] = outlet - if default: - self.to(default) - self._default = default + def select_outlets(self, event): + return list(self._name_to_outlet.keys()) async def _do(self, event): - if not self._outlets or event is _termination_obj: - return await super()._do_downstream(event) - chosen_outlet = None - element = self._get_event_or_body(event) - for outlet, condition in self._choice_array: - if condition(element): - chosen_outlet = outlet - break - if chosen_outlet: - await chosen_outlet._do(event) - elif self._default: - await self._default._do(event) + if event is _termination_obj: + return await self._do_downstream(_termination_obj) + else: + event_body = event if self._full_event else event.body + outlet_names = self.select_outlets(event_body) + outlets = [] + for outlet_name in outlet_names: + if outlet_name not in self._name_to_outlet: + raise ValueError( + f"select_outlets() returned outlet name '{outlet_name}', which is not one of the" + f"defined outlets: " + ", ".join(self._name_to_outlet) + ) + outlet = self._name_to_outlet[outlet_name] + outlets.append(outlet) + return await self._do_downstream(event, outlets=outlets) class Recover(Flow): diff --git a/storey/sources.py b/storey/sources.py index 0abca29a..6c10853a 100644 --- a/storey/sources.py +++ b/storey/sources.py @@ -313,6 +313,7 @@ async def _run_loop(self): await _commit_handled_events(self._outstanding_offsets, committer, commit_all=True) self._termination_future.set_result(termination_result) except BaseException as ex: + traceback.print_exc() if self.logger: message = "An error was raised" raised_by = getattr(ex, "_raised_by_storey_step", None) diff --git a/tests/test_flow.py b/tests/test_flow.py index bd484069..755597ce 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -1690,26 +1690,42 @@ def boom(_): def test_choice(): - small_reduce = Reduce(0, lambda acc, x: acc + x) + class MyChoice(Choice): + def select_outlets(self, event): + outlets = ["all_events"] + if event > 5: + outlets.append("more_than_five") + else: + outlets.append("up_to_five") + return outlets - big_reduce = build_flow([Map(lambda x: x * 100), Reduce(0, lambda acc, x: acc + x)]) + source = SyncEmitSource() + my_choice = MyChoice(termination_result_fn=lambda x, y: x + y) + all_events = Map(lambda x: x, name="all_events") + more_than_five = Map(lambda x: x * 10, name="more_than_five") + up_to_five = Map(lambda x: x * 100, name="up_to_five") + sum_up_all_events = Reduce(0, lambda acc, x: acc + x) + sum_up_more_than_five = Reduce(0, lambda acc, x: acc + x) + sum_up_up_to_five = Reduce(0, lambda acc, x: acc + x) + + source.to(my_choice) + my_choice.to(all_events) + my_choice.to(more_than_five) + my_choice.to(up_to_five) + all_events.to(sum_up_all_events) + more_than_five.to(sum_up_more_than_five) + up_to_five.to(sum_up_up_to_five) - controller = build_flow( - [ - SyncEmitSource(), - Choice( - [(big_reduce, lambda x: x % 2 == 0)], - default=small_reduce, - termination_result_fn=lambda x, y: x + y, - ), - ] - ).run() + controller = source.run() - for i in range(10): + for i in range(4, 8): controller.emit(i) + controller.terminate() termination_result = controller.await_termination() - assert termination_result == 2025 + + expected = sum(range(4, 8)) + sum(range(6, 8)) * 10 + sum(range(4, 6)) * 100 + assert termination_result == expected def test_metadata(): From 59fa81aca321fa9ad3c81d20e5cf1ec1dbf50178 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Sun, 15 Sep 2024 14:13:11 +0800 Subject: [PATCH 2/5] Add missing space --- storey/flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/storey/flow.py b/storey/flow.py index d8798e9e..52341481 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -371,7 +371,7 @@ async def _do(self, event): for outlet_name in outlet_names: if outlet_name not in self._name_to_outlet: raise ValueError( - f"select_outlets() returned outlet name '{outlet_name}', which is not one of the" + f"select_outlets() returned outlet name '{outlet_name}', which is not one of the " f"defined outlets: " + ", ".join(self._name_to_outlet) ) outlet = self._name_to_outlet[outlet_name] From 56da4e96e5cbce0cc132103952a412b6fec2e7b5 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Sun, 15 Sep 2024 14:30:58 +0800 Subject: [PATCH 3/5] Hack to avoid issue with mlrun preview --- storey/flow.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index 52341481..31da3fe3 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -357,6 +357,8 @@ def _init(self, **kwargs): if outlet.name in self._name_to_outlet: raise ValueError(f"Ambiguous outlet name '{outlet.name}' in Choice step") self._name_to_outlet[outlet.name] = outlet + # TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget + self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"] def select_outlets(self, event): return list(self._name_to_outlet.keys()) @@ -368,14 +370,18 @@ async def _do(self, event): event_body = event if self._full_event else event.body outlet_names = self.select_outlets(event_body) outlets = [] - for outlet_name in outlet_names: - if outlet_name not in self._name_to_outlet: - raise ValueError( - f"select_outlets() returned outlet name '{outlet_name}', which is not one of the " - f"defined outlets: " + ", ".join(self._name_to_outlet) - ) - outlet = self._name_to_outlet[outlet_name] + if self._passthrough_for_preview: + outlet = self._name_to_outlet["dataframe"] outlets.append(outlet) + else: + for outlet_name in outlet_names: + if outlet_name not in self._name_to_outlet: + raise ValueError( + f"select_outlets() returned outlet name '{outlet_name}', which is not one of the " + f"defined outlets: " + ", ".join(self._name_to_outlet) + ) + outlet = self._name_to_outlet[outlet_name] + outlets.append(outlet) return await self._do_downstream(event, outlets=outlets) From 207073d4db6205080e41fc63133ce4122e49b3c5 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Sun, 15 Sep 2024 14:37:44 +0800 Subject: [PATCH 4/5] Improve docs --- storey/flow.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/storey/flow.py b/storey/flow.py index 31da3fe3..6b34f10e 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -348,7 +348,10 @@ def _get_uuid(self): class Choice(Flow): - """Redirects each input element into any number of predetermined downstream steps.""" + """ + Redirects each input element into any number of predetermined downstream steps. Override select_outlets() + to route events to any number of downstream steps. + """ def _init(self, **kwargs): super()._init() @@ -361,6 +364,10 @@ def _init(self, **kwargs): self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"] def select_outlets(self, event): + """ + Override this method to route events based on a customer logic. The default implementation will route all + events to all outlets. + """ return list(self._name_to_outlet.keys()) async def _do(self, event): From 81163eecbf6b37336c7ab154c05712d241025ee1 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Tue, 15 Oct 2024 12:35:07 +0800 Subject: [PATCH 5/5] Remove accidental kwargs, add type annotation --- storey/flow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index 6b34f10e..7f2df3f0 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -353,7 +353,7 @@ class Choice(Flow): to route events to any number of downstream steps. """ - def _init(self, **kwargs): + def _init(self): super()._init() self._name_to_outlet = {} for outlet in self._outlets: @@ -363,7 +363,7 @@ def _init(self, **kwargs): # TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"] - def select_outlets(self, event): + def select_outlets(self, event) -> List[str]: """ Override this method to route events based on a customer logic. The default implementation will route all events to all outlets.