Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite Choice step to make it usable from mlrun #537

Merged
merged 6 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 51 additions & 48 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,16 @@ def _event_string(event):
def _should_terminate(self):
return self._termination_received == len(self._inlets)

async def _do_downstream(self, event):
if not self._outlets:
async def _do_downstream(self, event, outlets=None):
outlets = self._outlets if outlets is None else outlets
if not outlets:
return
if event is _termination_obj:
# Only propagate the termination object once we received one per inlet
self._outlets[0]._termination_received += 1
if self._outlets[0]._should_terminate():
self._termination_result = await self._outlets[0]._do(_termination_obj)
for outlet in self._outlets[1:] + self._get_recovery_steps():
outlets[0]._termination_received += 1
if outlets[0]._should_terminate():
self._termination_result = await outlets[0]._do(_termination_obj)
for outlet in outlets[1:] + self._get_recovery_steps():
outlet._termination_received += 1
if outlet._should_terminate():
self._termination_result = self._termination_result_fn(
Expand All @@ -269,28 +270,28 @@ async def _do_downstream(self, event):
return self._termination_result
# If there is more than one outlet, allow concurrent execution.
tasks = []
if len(self._outlets) > 1:
if len(outlets) > 1:
awaitable_result = event._awaitable_result
event._awaitable_result = None
original_events = getattr(event, "_original_events", None)
# Temporarily delete self-reference to avoid deepcopy getting stuck in an infinite loop
event._original_events = None
for i in range(1, len(self._outlets)):
for i in range(1, len(outlets)):
event_copy = copy.deepcopy(event)
event_copy._awaitable_result = awaitable_result
event_copy._original_events = original_events
tasks.append(asyncio.get_running_loop().create_task(self._outlets[i]._do_and_recover(event_copy)))
tasks.append(asyncio.get_running_loop().create_task(outlets[i]._do_and_recover(event_copy)))
# Set self-reference back after deepcopy
event._original_events = original_events
event._awaitable_result = awaitable_result
if self.verbose and self.logger:
step_name = self.name
event_string = self._event_string(event)
self.logger.debug(f"{step_name} -> {self._outlets[0].name} | {event_string}")
await self._outlets[0]._do_and_recover(event) # Optimization - avoids creating a task for the first outlet.
self.logger.debug(f"{step_name} -> {outlets[0].name} | {event_string}")
await outlets[0]._do_and_recover(event) # Optimization - avoids creating a task for the first outlet.
for i, task in enumerate(tasks, start=1):
if self.verbose and self.logger:
self.logger.debug(f"{step_name} -> {self._outlets[i].name} | {event_string}")
self.logger.debug(f"{step_name} -> {outlets[i].name} | {event_string}")
await task

def _get_event_or_body(self, event):
Expand Down Expand Up @@ -347,46 +348,48 @@ def _get_uuid(self):


class Choice(Flow):
"""Redirects each input element into at most one of multiple downstreams.

:param choice_array: a list of (downstream, condition) tuples, where downstream is a step and condition is a
function. The first condition in the list to evaluate as true for an input element causes that element to
be redirected to that downstream step.
:type choice_array: tuple of (Flow, Function (Event=>boolean))
:param default: a default step for events that did not match any condition in choice_array. If not set, elements
that don't match any condition will be discarded.
:type default: Flow
:param name: Name of this step, as it should appear in logs. Defaults to class name (Choice).
:type name: string
:param full_event: Whether user functions should receive and return Event objects (when True),
or only the payload (when False). Defaults to False.
:type full_event: boolean
"""
Redirects each input element into any number of predetermined downstream steps. Override select_outlets()
to route events to any number of downstream steps.
"""

def __init__(self, choice_array, default=None, **kwargs):
Flow.__init__(self, **kwargs)

self._choice_array = choice_array
for outlet, _ in choice_array:
self.to(outlet)

if default:
self.to(default)
self._default = default
def _init(self, **kwargs):
super()._init()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to make sure this is not an error: before the change, the kwargs were passed to the Flow init, so I would expect to see super().__init__(**kwargs) here unless it is redundant. I think the outlets param should be passed here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, the **kwargs need to be removed. There are no kwargs in _init(). Note that _init is unrelated to __init__.

self._name_to_outlet = {}
for outlet in self._outlets:
if outlet.name in self._name_to_outlet:
raise ValueError(f"Ambiguous outlet name '{outlet.name}' in Choice step")
self._name_to_outlet[outlet.name] = outlet
# TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget
self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"]

def select_outlets(self, event):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a type hint, and also mention in the docstring that the selected outlets should be returned as a list of strings (the outlet names).

Suggested change
def select_outlets(self, event):
def select_outlets(self, event) -> List[str]:

"""
Override this method to route events based on custom logic. The default implementation will route all
events to all outlets.
"""
return list(self._name_to_outlet.keys())

async def _do(self, event):
if not self._outlets or event is _termination_obj:
return await super()._do_downstream(event)
chosen_outlet = None
element = self._get_event_or_body(event)
for outlet, condition in self._choice_array:
if condition(element):
chosen_outlet = outlet
break
if chosen_outlet:
await chosen_outlet._do(event)
elif self._default:
await self._default._do(event)
if event is _termination_obj:
return await self._do_downstream(_termination_obj)
else:
event_body = event if self._full_event else event.body
outlet_names = self.select_outlets(event_body)
outlets = []
if self._passthrough_for_preview:
outlet = self._name_to_outlet["dataframe"]
outlets.append(outlet)
Comment on lines +380 to +382
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a TODO here as well to remember to delete this "hack"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really want to duplicate the TODO that's already on the attribute definition.

else:
for outlet_name in outlet_names:
if outlet_name not in self._name_to_outlet:
raise ValueError(
f"select_outlets() returned outlet name '{outlet_name}', which is not one of the "
f"defined outlets: " + ", ".join(self._name_to_outlet)
)
outlet = self._name_to_outlet[outlet_name]
outlets.append(outlet)
return await self._do_downstream(event, outlets=outlets)


class Recover(Flow):
Expand Down
1 change: 1 addition & 0 deletions storey/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ async def _run_loop(self):
await _commit_handled_events(self._outstanding_offsets, committer, commit_all=True)
self._termination_future.set_result(termination_result)
except BaseException as ex:
traceback.print_exc()
if self.logger:
message = "An error was raised"
raised_by = getattr(ex, "_raised_by_storey_step", None)
Expand Down
44 changes: 30 additions & 14 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,26 +1690,42 @@ def boom(_):


def test_choice():
small_reduce = Reduce(0, lambda acc, x: acc + x)
class MyChoice(Choice):
def select_outlets(self, event):
outlets = ["all_events"]
if event > 5:
outlets.append("more_than_five")
else:
outlets.append("up_to_five")
return outlets

big_reduce = build_flow([Map(lambda x: x * 100), Reduce(0, lambda acc, x: acc + x)])
source = SyncEmitSource()
my_choice = MyChoice(termination_result_fn=lambda x, y: x + y)
all_events = Map(lambda x: x, name="all_events")
more_than_five = Map(lambda x: x * 10, name="more_than_five")
up_to_five = Map(lambda x: x * 100, name="up_to_five")
sum_up_all_events = Reduce(0, lambda acc, x: acc + x)
sum_up_more_than_five = Reduce(0, lambda acc, x: acc + x)
sum_up_up_to_five = Reduce(0, lambda acc, x: acc + x)

source.to(my_choice)
my_choice.to(all_events)
my_choice.to(more_than_five)
my_choice.to(up_to_five)
all_events.to(sum_up_all_events)
more_than_five.to(sum_up_more_than_five)
up_to_five.to(sum_up_up_to_five)

controller = build_flow(
[
SyncEmitSource(),
Choice(
[(big_reduce, lambda x: x % 2 == 0)],
default=small_reduce,
termination_result_fn=lambda x, y: x + y,
),
]
).run()
controller = source.run()

for i in range(10):
for i in range(4, 8):
controller.emit(i)

controller.terminate()
termination_result = controller.await_termination()
assert termination_result == 2025

expected = sum(range(4, 8)) + sum(range(6, 8)) * 10 + sum(range(4, 6)) * 100
assert termination_result == expected


def test_metadata():
Expand Down
Loading