-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rewrite Choice step to make it usable from mlrun #537
Merged
+82
−62
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
4e6645d
Rewrite Choice step to make it usable from mlrun
gtopper 59fa81a
Add missing space
gtopper 56da4e9
Hack to avoid issue with mlrun preview
gtopper 207073d
Improve docs
gtopper 81163ee
Remove accidental kwargs, add type annotation
gtopper 7a544ce
Merge remote-tracking branch 'mlrun/development' into ML-7818
gtopper File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -252,15 +252,16 @@ def _event_string(event): | |
def _should_terminate(self): | ||
return self._termination_received == len(self._inlets) | ||
|
||
async def _do_downstream(self, event): | ||
if not self._outlets: | ||
async def _do_downstream(self, event, outlets=None): | ||
outlets = self._outlets if outlets is None else outlets | ||
if not outlets: | ||
return | ||
if event is _termination_obj: | ||
# Only propagate the termination object once we received one per inlet | ||
self._outlets[0]._termination_received += 1 | ||
if self._outlets[0]._should_terminate(): | ||
self._termination_result = await self._outlets[0]._do(_termination_obj) | ||
for outlet in self._outlets[1:] + self._get_recovery_steps(): | ||
outlets[0]._termination_received += 1 | ||
if outlets[0]._should_terminate(): | ||
self._termination_result = await outlets[0]._do(_termination_obj) | ||
for outlet in outlets[1:] + self._get_recovery_steps(): | ||
outlet._termination_received += 1 | ||
if outlet._should_terminate(): | ||
self._termination_result = self._termination_result_fn( | ||
|
@@ -269,28 +270,28 @@ async def _do_downstream(self, event): | |
return self._termination_result | ||
# If there is more than one outlet, allow concurrent execution. | ||
tasks = [] | ||
if len(self._outlets) > 1: | ||
if len(outlets) > 1: | ||
awaitable_result = event._awaitable_result | ||
event._awaitable_result = None | ||
original_events = getattr(event, "_original_events", None) | ||
# Temporarily delete self-reference to avoid deepcopy getting stuck in an infinite loop | ||
event._original_events = None | ||
for i in range(1, len(self._outlets)): | ||
for i in range(1, len(outlets)): | ||
event_copy = copy.deepcopy(event) | ||
event_copy._awaitable_result = awaitable_result | ||
event_copy._original_events = original_events | ||
tasks.append(asyncio.get_running_loop().create_task(self._outlets[i]._do_and_recover(event_copy))) | ||
tasks.append(asyncio.get_running_loop().create_task(outlets[i]._do_and_recover(event_copy))) | ||
# Set self-reference back after deepcopy | ||
event._original_events = original_events | ||
event._awaitable_result = awaitable_result | ||
if self.verbose and self.logger: | ||
step_name = self.name | ||
event_string = self._event_string(event) | ||
self.logger.debug(f"{step_name} -> {self._outlets[0].name} | {event_string}") | ||
await self._outlets[0]._do_and_recover(event) # Optimization - avoids creating a task for the first outlet. | ||
self.logger.debug(f"{step_name} -> {outlets[0].name} | {event_string}") | ||
await outlets[0]._do_and_recover(event) # Optimization - avoids creating a task for the first outlet. | ||
for i, task in enumerate(tasks, start=1): | ||
if self.verbose and self.logger: | ||
self.logger.debug(f"{step_name} -> {self._outlets[i].name} | {event_string}") | ||
self.logger.debug(f"{step_name} -> {outlets[i].name} | {event_string}") | ||
await task | ||
|
||
def _get_event_or_body(self, event): | ||
|
@@ -347,46 +348,48 @@ def _get_uuid(self): | |
|
||
|
||
class Choice(Flow): | ||
"""Redirects each input element into at most one of multiple downstreams. | ||
|
||
:param choice_array: a list of (downstream, condition) tuples, where downstream is a step and condition is a | ||
function. The first condition in the list to evaluate as true for an input element causes that element to | ||
be redirected to that downstream step. | ||
:type choice_array: tuple of (Flow, Function (Event=>boolean)) | ||
:param default: a default step for events that did not match any condition in choice_array. If not set, elements | ||
that don't match any condition will be discarded. | ||
:type default: Flow | ||
:param name: Name of this step, as it should appear in logs. Defaults to class name (Choice). | ||
:type name: string | ||
:param full_event: Whether user functions should receive and return Event objects (when True), | ||
or only the payload (when False). Defaults to False. | ||
:type full_event: boolean | ||
""" | ||
Redirects each input element into any number of predetermined downstream steps. Override select_outlets() | ||
to route events to any number of downstream steps. | ||
""" | ||
|
||
def __init__(self, choice_array, default=None, **kwargs): | ||
Flow.__init__(self, **kwargs) | ||
|
||
self._choice_array = choice_array | ||
for outlet, _ in choice_array: | ||
self.to(outlet) | ||
|
||
if default: | ||
self.to(default) | ||
self._default = default | ||
def _init(self): | ||
super()._init() | ||
self._name_to_outlet = {} | ||
for outlet in self._outlets: | ||
if outlet.name in self._name_to_outlet: | ||
raise ValueError(f"Ambiguous outlet name '{outlet.name}' in Choice step") | ||
self._name_to_outlet[outlet.name] = outlet | ||
# TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget | ||
self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"] | ||
|
||
def select_outlets(self, event) -> List[str]: | ||
""" | ||
Override this method to route events based on a customer logic. The default implementation will route all | ||
events to all outlets. | ||
""" | ||
return list(self._name_to_outlet.keys()) | ||
|
||
async def _do(self, event): | ||
if not self._outlets or event is _termination_obj: | ||
return await super()._do_downstream(event) | ||
chosen_outlet = None | ||
element = self._get_event_or_body(event) | ||
for outlet, condition in self._choice_array: | ||
if condition(element): | ||
chosen_outlet = outlet | ||
break | ||
if chosen_outlet: | ||
await chosen_outlet._do(event) | ||
elif self._default: | ||
await self._default._do(event) | ||
if event is _termination_obj: | ||
return await self._do_downstream(_termination_obj) | ||
else: | ||
event_body = event if self._full_event else event.body | ||
outlet_names = self.select_outlets(event_body) | ||
outlets = [] | ||
if self._passthrough_for_preview: | ||
outlet = self._name_to_outlet["dataframe"] | ||
outlets.append(outlet) | ||
Comment on lines
+380
to
+382
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a TODO here as well to remember to delete this "hack" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't really want to duplicate the TODO that's already on the attribute definition. |
||
else: | ||
for outlet_name in outlet_names: | ||
if outlet_name not in self._name_to_outlet: | ||
raise ValueError( | ||
f"select_outlets() returned outlet name '{outlet_name}', which is not one of the " | ||
f"defined outlets: " + ", ".join(self._name_to_outlet) | ||
) | ||
outlet = self._name_to_outlet[outlet_name] | ||
outlets.append(outlet) | ||
return await self._do_downstream(event, outlets=outlets) | ||
|
||
|
||
class Recover(Flow): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to make sure it is not an error: before the change the
kwargs
were sent at theFlow
init so I would expect to seesuper().__init__(**kwargs)
unless it is redundant. I think theoutlets
param should be passed here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, the
**kwargs
need to be removed. There's no kwargs in_init()
. Note that_init
is unrelated to__init__
.