-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rewrite Choice step to make it usable from mlrun #537
Changes from 4 commits
4e6645d
59fa81a
56da4e9
207073d
81163ee
7a544ce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -252,15 +252,16 @@ def _event_string(event): | |||||
def _should_terminate(self): | ||||||
return self._termination_received == len(self._inlets) | ||||||
|
||||||
async def _do_downstream(self, event): | ||||||
if not self._outlets: | ||||||
async def _do_downstream(self, event, outlets=None): | ||||||
outlets = self._outlets if outlets is None else outlets | ||||||
if not outlets: | ||||||
return | ||||||
if event is _termination_obj: | ||||||
# Only propagate the termination object once we received one per inlet | ||||||
self._outlets[0]._termination_received += 1 | ||||||
if self._outlets[0]._should_terminate(): | ||||||
self._termination_result = await self._outlets[0]._do(_termination_obj) | ||||||
for outlet in self._outlets[1:] + self._get_recovery_steps(): | ||||||
outlets[0]._termination_received += 1 | ||||||
if outlets[0]._should_terminate(): | ||||||
self._termination_result = await outlets[0]._do(_termination_obj) | ||||||
for outlet in outlets[1:] + self._get_recovery_steps(): | ||||||
outlet._termination_received += 1 | ||||||
if outlet._should_terminate(): | ||||||
self._termination_result = self._termination_result_fn( | ||||||
|
@@ -269,28 +270,28 @@ async def _do_downstream(self, event): | |||||
return self._termination_result | ||||||
# If there is more than one outlet, allow concurrent execution. | ||||||
tasks = [] | ||||||
if len(self._outlets) > 1: | ||||||
if len(outlets) > 1: | ||||||
awaitable_result = event._awaitable_result | ||||||
event._awaitable_result = None | ||||||
original_events = getattr(event, "_original_events", None) | ||||||
# Temporarily delete self-reference to avoid deepcopy getting stuck in an infinite loop | ||||||
event._original_events = None | ||||||
for i in range(1, len(self._outlets)): | ||||||
for i in range(1, len(outlets)): | ||||||
event_copy = copy.deepcopy(event) | ||||||
event_copy._awaitable_result = awaitable_result | ||||||
event_copy._original_events = original_events | ||||||
tasks.append(asyncio.get_running_loop().create_task(self._outlets[i]._do_and_recover(event_copy))) | ||||||
tasks.append(asyncio.get_running_loop().create_task(outlets[i]._do_and_recover(event_copy))) | ||||||
# Set self-reference back after deepcopy | ||||||
event._original_events = original_events | ||||||
event._awaitable_result = awaitable_result | ||||||
if self.verbose and self.logger: | ||||||
step_name = self.name | ||||||
event_string = self._event_string(event) | ||||||
self.logger.debug(f"{step_name} -> {self._outlets[0].name} | {event_string}") | ||||||
await self._outlets[0]._do_and_recover(event) # Optimization - avoids creating a task for the first outlet. | ||||||
self.logger.debug(f"{step_name} -> {outlets[0].name} | {event_string}") | ||||||
await outlets[0]._do_and_recover(event) # Optimization - avoids creating a task for the first outlet. | ||||||
for i, task in enumerate(tasks, start=1): | ||||||
if self.verbose and self.logger: | ||||||
self.logger.debug(f"{step_name} -> {self._outlets[i].name} | {event_string}") | ||||||
self.logger.debug(f"{step_name} -> {outlets[i].name} | {event_string}") | ||||||
await task | ||||||
|
||||||
def _get_event_or_body(self, event): | ||||||
|
@@ -347,46 +348,48 @@ def _get_uuid(self): | |||||
|
||||||
|
||||||
class Choice(Flow): | ||||||
"""Redirects each input element into at most one of multiple downstreams. | ||||||
|
||||||
:param choice_array: a list of (downstream, condition) tuples, where downstream is a step and condition is a | ||||||
function. The first condition in the list to evaluate as true for an input element causes that element to | ||||||
be redirected to that downstream step. | ||||||
:type choice_array: tuple of (Flow, Function (Event=>boolean)) | ||||||
:param default: a default step for events that did not match any condition in choice_array. If not set, elements | ||||||
that don't match any condition will be discarded. | ||||||
:type default: Flow | ||||||
:param name: Name of this step, as it should appear in logs. Defaults to class name (Choice). | ||||||
:type name: string | ||||||
:param full_event: Whether user functions should receive and return Event objects (when True), | ||||||
or only the payload (when False). Defaults to False. | ||||||
:type full_event: boolean | ||||||
""" | ||||||
Redirects each input element into any number of predetermined downstream steps. Override select_outlets() | ||||||
to route events to any number of downstream steps. | ||||||
""" | ||||||
|
||||||
def __init__(self, choice_array, default=None, **kwargs): | ||||||
Flow.__init__(self, **kwargs) | ||||||
|
||||||
self._choice_array = choice_array | ||||||
for outlet, _ in choice_array: | ||||||
self.to(outlet) | ||||||
|
||||||
if default: | ||||||
self.to(default) | ||||||
self._default = default | ||||||
def _init(self, **kwargs): | ||||||
super()._init() | ||||||
self._name_to_outlet = {} | ||||||
for outlet in self._outlets: | ||||||
if outlet.name in self._name_to_outlet: | ||||||
raise ValueError(f"Ambiguous outlet name '{outlet.name}' in Choice step") | ||||||
self._name_to_outlet[outlet.name] = outlet | ||||||
# TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget | ||||||
self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"] | ||||||
|
||||||
def select_outlets(self, event): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would add a type hint and also in the doc string mention that the selected outlets should be returned in a list of strings - the outlet names.
Suggested change
|
||||||
""" | ||||||
Override this method to route events based on a customer logic. The default implementation will route all | ||||||
events to all outlets. | ||||||
""" | ||||||
return list(self._name_to_outlet.keys()) | ||||||
|
||||||
async def _do(self, event): | ||||||
if not self._outlets or event is _termination_obj: | ||||||
return await super()._do_downstream(event) | ||||||
chosen_outlet = None | ||||||
element = self._get_event_or_body(event) | ||||||
for outlet, condition in self._choice_array: | ||||||
if condition(element): | ||||||
chosen_outlet = outlet | ||||||
break | ||||||
if chosen_outlet: | ||||||
await chosen_outlet._do(event) | ||||||
elif self._default: | ||||||
await self._default._do(event) | ||||||
if event is _termination_obj: | ||||||
return await self._do_downstream(_termination_obj) | ||||||
else: | ||||||
event_body = event if self._full_event else event.body | ||||||
outlet_names = self.select_outlets(event_body) | ||||||
outlets = [] | ||||||
if self._passthrough_for_preview: | ||||||
outlet = self._name_to_outlet["dataframe"] | ||||||
outlets.append(outlet) | ||||||
Comment on lines
+380
to
+382
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a TODO here as well to remember to delete this "hack" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't really want to duplicate the TODO that's already on the attribute definition. |
||||||
else: | ||||||
for outlet_name in outlet_names: | ||||||
if outlet_name not in self._name_to_outlet: | ||||||
raise ValueError( | ||||||
f"select_outlets() returned outlet name '{outlet_name}', which is not one of the " | ||||||
f"defined outlets: " + ", ".join(self._name_to_outlet) | ||||||
) | ||||||
outlet = self._name_to_outlet[outlet_name] | ||||||
outlets.append(outlet) | ||||||
return await self._do_downstream(event, outlets=outlets) | ||||||
|
||||||
|
||||||
class Recover(Flow): | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to make sure it is not an error: before the change the
kwargs
were sent at theFlow
init so I would expect to seesuper().__init__(**kwargs)
unless it is redundant. I think theoutlets
param should be passed here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, the
**kwargs
need to be removed. There's no kwargs in_init()
. Note that_init
is unrelated to__init__
.