Flytekit: Rename map_task to map, replace min_successes and min_success_ratio with tolerance, rename max_parallelism to concurrency #3107

Open: wants to merge 4 commits into base: master
11 changes: 9 additions & 2 deletions flytekit/__init__.py
@@ -205,6 +205,7 @@

import os
import sys
import warnings
from typing import Generator

from rich import traceback
@@ -216,10 +217,9 @@
else:
from importlib.metadata import entry_points


from flytekit._version import __version__
from flytekit.configuration import Config
from flytekit.core.array_node_map_task import map_task
from flytekit.core.array_node_map_task import map
Contributor review comment:

Consider maintaining backward compatibility for imports

Consider keeping both map_task and map imports to maintain backward compatibility. The alias is defined later, but importing directly as map may break existing code that uses map_task.

Suggested change:
from flytekit.core.array_node_map_task import map
from flytekit.core.array_node_map_task import map_task

Code Review Run #d47fe6 (resolution: it was incorrectly flagged).
from flytekit.core.artifact import Artifact
from flytekit.core.base_sql_task import SQLTask
from flytekit.core.base_task import SecurityContext, TaskMetadata, kwtypes
@@ -265,6 +265,13 @@
StructuredDatasetType,
)

warnings.warn(
"'map_task' is deprecated and will be removed in a future version. Use 'map' instead.",
DeprecationWarning,
stacklevel=2,
)
map_task = map


def current_context() -> ExecutionParameters:
"""
16 changes: 13 additions & 3 deletions flytekit/clis/sdk_in_container/run.py
@@ -255,16 +255,26 @@ class RunLevelParams(PyFlyteParams):
),
)
)
max_parallelism: int = make_click_option_field(

concurrency: int = make_click_option_field(
click.Option(
param_decls=["--max-parallelism"],
param_decls=["--concurrency"],
required=False,
type=int,
show_default=True,
help="Number of nodes of a workflow that can be executed in parallel. If not specified,"
" project/domain defaults are used. If 0 then it is unlimited.",
)
)
max_parallelism: int = make_click_option_field(
click.Option(
param_decls=["--max-parallelism"],
required=False,
type=int,
show_default=True,
help="[Deprecated] Use --concurrency instead",
)
)
disable_notifications: bool = make_click_option_field(
click.Option(
param_decls=["--disable-notifications"],
@@ -516,7 +526,7 @@ def options_from_run_params(run_level_params: RunLevelParams) -> Options:
raw_output_data_config=RawOutputDataConfig(output_location_prefix=run_level_params.raw_output_data_prefix)
if run_level_params.raw_output_data_prefix
else None,
max_parallelism=run_level_params.max_parallelism,
concurrency=run_level_params.concurrency if run_level_params.concurrency is not None else run_level_params.max_parallelism,
disable_notifications=run_level_params.disable_notifications,
security_context=security.SecurityContext(
run_as=security.Identity(k8s_service_account=run_level_params.service_account)
48 changes: 39 additions & 9 deletions flytekit/core/array_node_map_task.py
@@ -3,6 +3,7 @@
import hashlib
import math
import os # TODO: use flytekit logger
import warnings
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast

@@ -369,11 +370,12 @@ def _raw_execute(self, **kwargs) -> Any:
return outputs


def map_task(
def map(
Contributor review comment:

Consider keeping descriptive function name

Consider keeping the original function name map_task instead of renaming to map, as it could conflict with Python's built-in map function and cause confusion. The original name was more descriptive of the function's purpose.

Suggested change:
def map(
def map_task(

Code Review Run #d47fe6 (resolution: it was incorrectly flagged).
target: Union[LaunchPlan, PythonFunctionTask, "FlyteLaunchPlan"],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: float = 1.0,
tolerance: Optional[Union[float, int]] = None,
min_successes: Optional[int] = None, # Deprecated
min_success_ratio: Optional[float] = None, # Deprecated
**kwargs,
):
"""
@@ -385,23 +387,51 @@ def map_task(
size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until
all inputs are processed. If set to 0, this means unbounded concurrency. If left unspecified, this means the
array node will inherit parallelism from the workflow
:param min_successes: The minimum number of successful executions
:param min_success_ratio: The minimum ratio of successful executions
:param tolerance: Failure tolerance threshold.
If float (0-1): represents minimum success ratio
If int (>1): represents minimum number of successes
:param min_successes: The minimum number of successful executions [Deprecated] Use tolerance instead
:param min_success_ratio: The minimum ratio of successful executions [Deprecated] Use tolerance instead
"""
from flytekit.remote import FlyteLaunchPlan

if min_successes is not None or min_success_ratio is not None:
warnings.warn(
"'min_successes' and 'min_success_ratio' are deprecated. Please use the 'tolerance' parameter instead.",
DeprecationWarning,
stacklevel=2,
)

computed_min_ratio = 1.0
computed_min_success = None

if tolerance is not None:
if isinstance(tolerance, float):
if not 0 <= tolerance <= 1:
raise ValueError("tolerance must be between 0 and 1")
computed_min_ratio = tolerance
elif isinstance(tolerance, int):
if tolerance < 1:
raise ValueError("tolerance must be greater than 0")
computed_min_success = tolerance
else:
raise TypeError("tolerance must be float or int")

final_min_ratio = computed_min_ratio if min_success_ratio is None else min_success_ratio
final_min_successes = computed_min_success if min_successes is None else min_successes

if isinstance(target, (LaunchPlan, FlyteLaunchPlan, ReferenceTask)):
return array_node(
target=target,
concurrency=concurrency,
min_successes=min_successes,
min_success_ratio=min_success_ratio,
min_successes=final_min_successes,
min_success_ratio=final_min_ratio,
)
return array_node_map_task(
task_function=target,
concurrency=concurrency,
min_successes=min_successes,
min_success_ratio=min_success_ratio,
min_successes=final_min_successes,
min_success_ratio=final_min_ratio,
**kwargs,
)
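The tolerance-to-legacy-parameter mapping above can be sketched as a standalone helper (resolve_tolerance is a hypothetical name for illustration; in the PR this logic lives inline in map):

```python
# Sketch: map a single `tolerance` value onto the legacy
# (min_success_ratio, min_successes) pair, mirroring the logic above.
from typing import Optional, Tuple, Union


def resolve_tolerance(tolerance: Optional[Union[float, int]]) -> Tuple[float, Optional[int]]:
    min_ratio, min_successes = 1.0, None  # defaults: every item must succeed
    if tolerance is None:
        return min_ratio, min_successes
    if isinstance(tolerance, bool):
        # bool is a subclass of int; reject it before the int branch
        raise TypeError("tolerance must be float or int")
    if isinstance(tolerance, float):
        if not 0 <= tolerance <= 1:
            raise ValueError("tolerance must be between 0 and 1")
        return tolerance, min_successes  # ratio form
    if isinstance(tolerance, int):
        if tolerance < 1:
            raise ValueError("tolerance must be greater than 0")
        return min_ratio, tolerance  # absolute-count form
    raise TypeError("tolerance must be float or int")
```

One consequence of the overloaded type: `tolerance=1` (int) asks for a single success, while `tolerance=1.0` (float) requires every item to succeed, so callers should choose the type deliberately.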

55 changes: 43 additions & 12 deletions flytekit/core/launch_plan.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing
import warnings
from typing import Any, Callable, Dict, List, Optional, Type

from flytekit.core import workflow as _annotated_workflow
@@ -129,7 +130,8 @@ def create(
labels: Optional[_common_models.Labels] = None,
annotations: Optional[_common_models.Annotations] = None,
raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
max_parallelism: Optional[int] = None,
max_parallelism: Optional[int] = None, # Deprecated: Use concurrency instead
concurrency: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
auth_role: Optional[_common_models.AuthRole] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
@@ -183,7 +185,8 @@ def create(
labels=labels,
annotations=annotations,
raw_output_data_config=raw_output_data_config,
max_parallelism=max_parallelism,
concurrency=concurrency, # Pass new parameter
max_parallelism=max_parallelism, # Pass deprecated parameter
security_context=security_context,
trigger=trigger,
overwrite_cache=overwrite_cache,
@@ -213,7 +216,8 @@ def get_or_create(
labels: Optional[_common_models.Labels] = None,
annotations: Optional[_common_models.Annotations] = None,
raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
max_parallelism: Optional[int] = None,
concurrency: Optional[int] = None,
max_parallelism: Optional[int] = None, # Deprecated
security_context: Optional[security.SecurityContext] = None,
auth_role: Optional[_common_models.AuthRole] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
@@ -241,9 +245,10 @@ def get_or_create(
:param annotations: Optional annotations to attach to executions created by this launch plan.
:param raw_output_data_config: Optional location of offloaded data for things like S3, etc.
:param auth_role: Add an auth role if necessary.
:param max_parallelism: Controls the maximum number of tasknodes that can be run in parallel for the entire
workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and
parallelism/concurrency of MapTasks is independent from this.
:param concurrency: Controls the maximum number of tasknodes that can be run in parallel for the entire
workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and
parallelism/concurrency of MapTasks is independent from this.
:param max_parallelism: [Deprecated] Use concurrency instead.
:param trigger: [alpha] This is a new syntax for specifying schedules.
:param overwrite_cache: If set to True, the execution will always overwrite cache
:param auto_activate: If set to True, the launch plan will be activated automatically on registration.
@@ -258,6 +263,7 @@
or annotations is not None
or raw_output_data_config is not None
or auth_role is not None
or concurrency is not None
or max_parallelism is not None
or security_context is not None
or trigger is not None
@@ -296,7 +302,11 @@
("labels", labels, cached_outputs["_labels"]),
("annotations", annotations, cached_outputs["_annotations"]),
("raw_output_data_config", raw_output_data_config, cached_outputs["_raw_output_data_config"]),
("max_parallelism", max_parallelism, cached_outputs["_max_parallelism"]),
(
"concurrency",
concurrency if concurrency is not None else max_parallelism,
cached_outputs.get("_concurrency", cached_outputs.get("_max_parallelism")),
),
("security_context", security_context, cached_outputs["_security_context"]),
("overwrite_cache", overwrite_cache, cached_outputs["_overwrite_cache"]),
("auto_activate", auto_activate, cached_outputs["_auto_activate"]),
@@ -326,7 +336,8 @@
labels,
annotations,
raw_output_data_config,
max_parallelism,
concurrency=concurrency,
Contributor review comment:

Refactor __init__ method signature

The __init__ method has too many parameters (14 > 5) and is missing a docstring and return type annotation.

Suggested change:
 -    def __init__(
 -        self,
 -        name: str,
 -        workflow: _annotated_workflow.WorkflowBase,
 -        parameters: _interface_models.ParameterMap,
 -        fixed_inputs: _literal_models.LiteralMap,
 -        schedule: Optional[_schedule_model.Schedule] = None,
 -        notifications: Optional[List[_common_models.Notification]] = None,
 -        labels: Optional[_common_models.Labels] = None,
 -        annotations: Optional[_common_models.Annotations] = None,
 -        raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
 -        max_parallelism: Optional[int] = None,
 -        security_context: Optional[security.SecurityContext] = None,
 -        trigger: Optional[LaunchPlanTriggerBase] = None,
 -        overwrite_cache: Optional[bool] = None,
 -        auto_activate: bool = False,
 -    ):
 +    @dataclass
 +    class Config:
 +        """Configuration for LaunchPlan initialization."""
 +        name: str
 +        workflow: _annotated_workflow.WorkflowBase
 +        parameters: _interface_models.ParameterMap
 +        fixed_inputs: _literal_models.LiteralMap
 +        schedule: _schedule_model.Schedule | None = None
 +        notifications: list[_common_models.Notification] | None = None
 +        labels: _common_models.Labels | None = None
 +        annotations: _common_models.Annotations | None = None
 +        raw_output_data_config: _common_models.RawOutputDataConfig | None = None
 +        max_parallelism: int | None = None
 +        security_context: security.SecurityContext | None = None
 +        trigger: LaunchPlanTriggerBase | None = None
 +        overwrite_cache: bool | None = None
 +        auto_activate: bool = False
 +
 +    def __init__(self, config: Config) -> None:

Code Review Run #99b31d (resolution: it was incorrectly flagged).
max_parallelism=max_parallelism,
auth_role=auth_role,
security_context=security_context,
trigger=trigger,
@@ -347,7 +358,8 @@ def __init__(
labels: Optional[_common_models.Labels] = None,
annotations: Optional[_common_models.Annotations] = None,
raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
max_parallelism: Optional[int] = None,
concurrency: Optional[int] = None,
max_parallelism: Optional[int] = None, # Deprecated
security_context: Optional[security.SecurityContext] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
@@ -367,7 +379,14 @@ def __init__(
self._labels = labels
self._annotations = annotations
self._raw_output_data_config = raw_output_data_config
self._max_parallelism = max_parallelism
self._concurrency = concurrency if concurrency is not None else max_parallelism
self._max_parallelism = self._concurrency
if max_parallelism is not None:
warnings.warn(
"max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.",
DeprecationWarning,
stacklevel=2,
)
self._security_context = security_context
self._trigger = trigger
self._overwrite_cache = overwrite_cache
@@ -385,7 +404,8 @@ def clone_with(
labels: Optional[_common_models.Labels] = None,
annotations: Optional[_common_models.Annotations] = None,
raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
max_parallelism: Optional[int] = None,
concurrency: Optional[int] = None,
max_parallelism: Optional[int] = None, # Deprecated
security_context: Optional[security.SecurityContext] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
@@ -401,6 +421,7 @@ def clone_with(
labels=labels or self.labels,
annotations=annotations or self.annotations,
raw_output_data_config=raw_output_data_config or self.raw_output_data_config,
concurrency=concurrency or self.concurrency,
max_parallelism=max_parallelism or self.max_parallelism,
security_context=security_context or self.security_context,
trigger=trigger,
@@ -466,7 +487,17 @@ def raw_output_data_config(self) -> Optional[_common_models.RawOutputDataConfig]

@property
def max_parallelism(self) -> Optional[int]:
return self._max_parallelism
"""[Deprecated] Use concurrency instead. This property is maintained for backward compatibility"""
warnings.warn(
"max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.",
DeprecationWarning,
stacklevel=2,
)
return self._concurrency

@property
def concurrency(self) -> Optional[int]:
return self._concurrency

@property
def security_context(self) -> Optional[security.SecurityContext]:
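The constructor-and-property deprecation flow used by LaunchPlan above can be sketched in isolation (LaunchOptions is a hypothetical class for illustration, not flytekit's):

```python
# Minimal standalone sketch of the constructor/property deprecation pattern:
# warn when the old name is supplied, fold both names into one stored value,
# and warn again on reads through the deprecated property.
import typing
import warnings


class LaunchOptions:
    def __init__(
        self,
        concurrency: typing.Optional[int] = None,
        max_parallelism: typing.Optional[int] = None,
    ):
        if max_parallelism is not None:
            warnings.warn(
                "max_parallelism is deprecated and will be removed in a future version. "
                "Use concurrency instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        # The new name wins; the deprecated value is only a fallback.
        self._concurrency = concurrency if concurrency is not None else max_parallelism

    @property
    def concurrency(self) -> typing.Optional[int]:
        return self._concurrency

    @property
    def max_parallelism(self) -> typing.Optional[int]:
        warnings.warn(
            "max_parallelism is deprecated and will be removed in a future version. "
            "Use concurrency instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._concurrency
```

Callers passing only the old keyword keep working, but every path that touches the old name emits a DeprecationWarning.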
55 changes: 42 additions & 13 deletions flytekit/core/options.py
@@ -1,4 +1,5 @@
import typing
import warnings
from dataclasses import dataclass

from flytekit.models import common as common_models
@@ -8,33 +9,61 @@
@dataclass
class Options(object):
"""
These are options that can be configured for a launchplan during registration or overridden during an execution.
For instance two people may want to run the same workflow but have the offloaded data stored in two different
These are options that can be configured for a launch plan during registration or overridden during an execution.
For instance, two people may want to run the same workflow but have the offloaded data stored in two different
buckets. Or you may want labels or annotations to be different. This object is used when launching an execution
in a Flyte backend, and also when registering launch plans.

Args:
labels: Custom labels to be applied to the execution resource
annotations: Custom annotations to be applied to the execution resource
security_context: Indicates security context for permissions triggered with this launch plan
raw_output_data_config: Optional location of offloaded data for things like S3, etc.
remote prefix for storage location of the form ``s3://<bucket>/key...`` or
``gcs://...`` or ``file://...``. If not specified will use the platform configured default. This is where
Attributes:
labels (typing.Optional[common_models.Labels]): Custom labels to be applied to the execution resource.
annotations (typing.Optional[common_models.Annotations]): Custom annotations to be applied to the execution resource.
raw_output_data_config (typing.Optional[common_models.RawOutputDataConfig]): Optional location of offloaded data
for things like S3, etc. Remote prefix for storage location of the form ``s3://<bucket>/key...`` or
``gcs://...`` or ``file://...``. If not specified, will use the platform-configured default. This is where
the data for offloaded types is stored.
max_parallelism: Controls the maximum number of tasknodes that can be run in parallel for the entire workflow.
notifications: List of notifications for this execution.
disable_notifications: This should be set to true if all notifications are intended to be disabled for this execution.
security_context (typing.Optional[security.SecurityContext]): Indicates security context for permissions triggered
with this launch plan.
concurrency (typing.Optional[int]): Controls the maximum number of task nodes that can be run in parallel for the
entire workflow.
notifications (typing.Optional[typing.List[common_models.Notification]]): List of notifications for this execution.
disable_notifications (typing.Optional[bool]): Set to True if all notifications are intended to be disabled
for this execution.
overwrite_cache (typing.Optional[bool]): When set to True, forces the execution to overwrite any existing cached values.
"""

labels: typing.Optional[common_models.Labels] = None
annotations: typing.Optional[common_models.Annotations] = None
raw_output_data_config: typing.Optional[common_models.RawOutputDataConfig] = None
security_context: typing.Optional[security.SecurityContext] = None
max_parallelism: typing.Optional[int] = None
concurrency: typing.Optional[int] = None
notifications: typing.Optional[typing.List[common_models.Notification]] = None
disable_notifications: typing.Optional[bool] = None
overwrite_cache: typing.Optional[bool] = None

@property
def max_parallelism(self) -> typing.Optional[int]:
"""
[Deprecated] Use concurrency instead. This property is maintained for backward compatibility
"""
warnings.warn(
"max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.",
DeprecationWarning,
stacklevel=2,
)
return self.concurrency

@max_parallelism.setter
def max_parallelism(self, value: typing.Optional[int]):
"""
Setter for max_parallelism (deprecated in favor of concurrency)
"""
warnings.warn(
"max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.",
DeprecationWarning,
stacklevel=2,
)
self.concurrency = value

@classmethod
def default_from(
cls,
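The Options pattern above, a dataclass field under the new name with a deprecated read/write alias, can be sketched standalone (ExecOptions is a hypothetical name for illustration):

```python
# Sketch: dataclass field `concurrency` plus a deprecated `max_parallelism`
# alias exposed as a property with a setter, mirroring Options above.
import typing
import warnings
from dataclasses import dataclass


@dataclass
class ExecOptions:
    concurrency: typing.Optional[int] = None

    @property
    def max_parallelism(self) -> typing.Optional[int]:
        warnings.warn(
            "max_parallelism is deprecated and will be removed in a future version. "
            "Use concurrency instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.concurrency

    @max_parallelism.setter
    def max_parallelism(self, value: typing.Optional[int]) -> None:
        warnings.warn(
            "max_parallelism is deprecated and will be removed in a future version. "
            "Use concurrency instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.concurrency = value
```

Because max_parallelism is not an annotated field, the dataclass machinery ignores it; reads and writes through the old name route to `concurrency` while emitting a DeprecationWarning.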