Skip to content

Commit

Permalink
Fix typing of task decorator for retry_condition_fn argument (#16621)
Browse files Browse the repository at this point in the history
Co-authored-by: nate nowack <[email protected]>
  • Loading branch information
peterbygrave and zzstoatzz authored Jan 17, 2025
1 parent 16e85ce commit 07b1cab
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 1 deletion.
38 changes: 37 additions & 1 deletion src/prefect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,7 +1642,43 @@ def task(
refresh_cache: Optional[bool] = None,
on_completion: Optional[list[StateHookCallable]] = None,
on_failure: Optional[list[StateHookCallable]] = None,
retry_condition_fn: Optional[Callable[[Task[P, Any], TaskRun, State], bool]] = None,
retry_condition_fn: Literal[None] = None,
viz_return_value: Any = None,
) -> Callable[[Callable[P, R]], Task[P, R]]:
...


# see https://github.com/PrefectHQ/prefect/issues/16380
@overload
def task(
__fn: Literal[None] = None,
*,
name: Optional[str] = None,
description: Optional[str] = None,
tags: Optional[Iterable[str]] = None,
version: Optional[str] = None,
cache_policy: Union[CachePolicy, type[NotSet]] = NotSet,
cache_key_fn: Optional[
Callable[["TaskRunContext", dict[str, Any]], Optional[str]]
] = None,
cache_expiration: Optional[datetime.timedelta] = None,
task_run_name: Optional[TaskRunNameValueOrCallable] = None,
retries: int = 0,
retry_delay_seconds: Union[
float, int, list[float], Callable[[int], list[float]], None
] = None,
retry_jitter_factor: Optional[float] = None,
persist_result: Optional[bool] = None,
result_storage: Optional[ResultStorage] = None,
result_storage_key: Optional[str] = None,
result_serializer: Optional[ResultSerializer] = None,
cache_result_in_memory: bool = True,
timeout_seconds: Union[int, float, None] = None,
log_prints: Optional[bool] = None,
refresh_cache: Optional[bool] = None,
on_completion: Optional[list[StateHookCallable]] = None,
on_failure: Optional[list[StateHookCallable]] = None,
retry_condition_fn: Optional[Callable[[Task[P, R], TaskRun, State], bool]] = None,
viz_return_value: Any = None,
) -> Callable[[Callable[P, R]], Task[P, R]]:
...
Expand Down
59 changes: 59 additions & 0 deletions tests/typesafety/test_tasks.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# yaml-language-server: $schema=https://raw.githubusercontent.com/typeddjango/pytest-mypy-plugins/master/pytest_mypy_plugins/schema.json
- case: prefect_task_decorator_no_args
main: |
from prefect import task
@task
def foo(bar: str) -> int:
return 42
reveal_type(foo)
out: "main:5: note: Revealed type is \"\
prefect.tasks.Task[[bar: builtins.str], builtins.int]\
\""

- case: prefect_task_decorator_call_with_no_args
main: |
from prefect import task
@task()
def foo(bar: str) -> int:
return 42
reveal_type(foo)
out: "main:5: note: Revealed type is \"\
prefect.tasks.Task[[bar: builtins.str], builtins.int]\
\""

- case: prefect_task_decorator_with_name_arg
main: |
from prefect import task
@task(name="bar")
def foo(bar: str) -> int:
return 42
reveal_type(foo)
out: "main:5: note: Revealed type is \"\
prefect.tasks.Task[[bar: builtins.str], builtins.int]\
\""

- case: prefect_task_decorator_with_retry_condition_fn_as_none_arg
main: |
from prefect.tasks import task
@task(retry_condition_fn=None)
def foo(bar: str) -> int:
return 42
reveal_type(foo)
out: "main:5: note: Revealed type is \"\
prefect.tasks.Task[[bar: builtins.str], builtins.int]\
\""

- case: prefect_task_decorator_with_retry_condition_fn_arg
main: |
from prefect.tasks import P, R, Task, task
from prefect.client.schemas import TaskRun
from prefect.states import State
def retry_condition_fn(task: Task[P, R], task_run: TaskRun, state: State) -> bool:
return False
@task(retry_condition_fn=retry_condition_fn)
def foo(bar: str) -> int:
return 42
reveal_type(foo)
out: "main:9: note: Revealed type is \"\
prefect.tasks.Task[[bar: builtins.str], builtins.int]\
\""

0 comments on commit 07b1cab

Please sign in to comment.