Skip to content

Commit

Permalink
Add modal.FunctionCall.gather to replace modal.functions.gather (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom authored Feb 25, 2025
1 parent cf0432e commit f83977d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 20 deletions.
43 changes: 26 additions & 17 deletions modal/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,26 +1597,35 @@ async def from_id(
fc._is_generator = is_generator
return fc

@staticmethod
async def gather(*function_calls: "_FunctionCall[Any]") -> list[Any]:
"""Wait until all Modal FunctionCall objects have results before returning.
async def _gather(*function_calls: _FunctionCall[ReturnType]) -> typing.Sequence[ReturnType]:
"""Wait until all Modal function calls have results before returning
Accepts a variable number of `FunctionCall` objects, as returned by `Function.spawn()`.
Accepts a variable number of FunctionCall objects as returned by `Function.spawn()`.
Returns a list of results from each FunctionCall, or raises an exception
from the first failing function call.
Returns a list of results from each function call, or raises an exception
of the first failing function call.
Examples:
```python notest
fc1 = slow_func_1.spawn()
fc2 = slow_func_2.spawn()
E.g.
result_1, result_2 = modal.FunctionCall.gather(fc1, fc2)
```
"""
try:
return await TaskContext.gather(*[fc.get() for fc in function_calls])
except Exception as exc:
# TODO: kill all running function calls
raise exc

```python notest
function_call_1 = slow_func_1.spawn()
function_call_2 = slow_func_2.spawn()

result_1, result_2 = gather(function_call_1, function_call_2)
```
"""
try:
return await TaskContext.gather(*[fc.get() for fc in function_calls])
except Exception as exc:
# TODO: kill all running function calls
raise exc
async def _gather(*function_calls: _FunctionCall[ReturnType]) -> typing.Sequence[ReturnType]:
"""Deprecated: Please use `modal.FunctionCall.gather()` instead."""
deprecation_warning(
(2025, 2, 24),
"`modal.functions.gather()` is deprecated; please use `modal.FunctionCall.gather()` instead.",
)
return await _FunctionCall.gather(*function_calls)
4 changes: 2 additions & 2 deletions test/function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from modal._utils.async_utils import synchronize_api
from modal._vendor import cloudpickle
from modal.exception import DeprecationError, ExecutionError, InvalidError
from modal.functions import Function, FunctionCall, gather
from modal.functions import Function, FunctionCall
from modal.runner import deploy_app
from modal_proto import api_pb2
from test.helpers import deploy_app_externally
Expand Down Expand Up @@ -431,7 +431,7 @@ def test_sync_parallelism(client, servicer):
with app.run(client=client):
t0 = time.time()
# NOTE tests breaks in macOS CI if the smaller time is smaller than ~300ms
res = gather(slo1_modal.spawn(0.31), slo1_modal.spawn(0.3))
res = FunctionCall.gather(slo1_modal.spawn(0.31), slo1_modal.spawn(0.3))
t1 = time.time()
assert res == [0.31, 0.3] # results should be ordered as inputs, not by completion time
assert t1 - t0 < 0.6 # less than the combined runtime, make sure they run in parallel
Expand Down
2 changes: 1 addition & 1 deletion test/supports/type_assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def other_func() -> str:
ret = typed_func.remote(a="hello")
assert_type(ret, float)

ret2 = modal.functions.gather(typed_func.spawn("bar"), other_func.spawn())
ret2 = modal.FunctionCall.gather(typed_func.spawn("bar"), other_func.spawn())
# This assertion doesn't work in mypy (it infers the more generic list[object]), but does work in pyright/vscode:
# assert_type(ret2, typing.List[typing.Union[float, str]])
mypy_compatible_ret: typing.Sequence[object] = ret2 # mypy infers to the broader "object" type instead
Expand Down

0 comments on commit f83977d

Please sign in to comment.